Session CryptHOL

Theory Misc_CryptHOL

(* Title: Misc_CryptHOL.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Miscellaneous library additions›

theory Misc_CryptHOL imports 
  Probabilistic_While.While_SPMF
  "HOL-Library.Rewrite"
  "HOL-Library.Simps_Case_Conv"
  "HOL-Library.Type_Length"
  "HOL-Eisbach.Eisbach"
  Coinductive.TLList
  Monad_Normalisation.Monad_Normalisation
  Monomorphic_Monad.Monomorphic_Monad
  Applicative_Lifting.Applicative
begin

(* Hide the short name of this constant; because of (open) it remains
   accessible under its fully qualified name. *)
hide_const (open) Henstock_Kurzweil_Integration.negligible

(* Remove eq_on_def from the default simp set for the rest of this theory
   and its descendants. *)
declare eq_on_def [simp del]

subsection ‹HOL›

lemma asm_rl_conv: "(PROP P  PROP P)  Trueprop True"
by(rule equal_intr_rule) iprover+

named_theorems if_distribs "Distributivity theorems for If"

lemma if_mono_cong: "b  x  x'; ¬ b  y  y'   If b x y  If b x' y'"
by simp

lemma if_cong_then: " b = b'; b'  t = t'; e = e'   If b t e = If b' t' e'"
by simp

lemma if_False_eq: " b  False; e = e'   If b t e = e'"
by auto

lemma imp_OO_imp [simp]: "(⟶) OO (⟶) = (⟶)"
by auto

lemma inj_on_fun_updD: " inj_on (f(x := y)) A; x  A   inj_on f A"
by(auto simp add: inj_on_def split: if_split_asm)

lemma disjoint_notin1: " A  B = {}; x  B   x  A" by auto

(* If Q is inhabited and every witness of Q has a witness of P below it,
   then the least witness of P is below the least witness of Q. *)
lemma Least_le_Least:
  fixes x :: "'a :: wellorder"
  assumes "Q x"
  and Q: "⋀x. Q x ⟹ ∃y≤x. P y"
  shows "Least P ≤ Least Q"
proof -
  (* Q is inhabited, hence its least element satisfies Q. *)
  have "Q (Least Q)" using assms(1) by(rule LeastI)
  (* Assumption Q then provides a witness of P below Least Q ... *)
  with Q obtain y where y: "y ≤ Least Q" and Py: "P y" by blast
  (* ... and Least P lies below every witness of P. *)
  from Py have "Least P ≤ y" by(rule Least_le)
  then show ?thesis using y by(rule order_trans)
qed

lemma is_empty_image [simp]: "Set.is_empty (f ` A) = Set.is_empty A"
  by(auto simp add: Set.is_empty_def)

subsection ‹Relations›

(* Image of a predicate under a relation: Imagep R P y holds when some x
   with P x is R-related to y (single introduction rule ImagepI). *)
inductive Imagep :: "('a ⇒ 'b ⇒ bool) ⇒ ('a ⇒ bool) ⇒ 'b ⇒ bool"
  for R P
where ImagepI: "⟦ P x; R x y ⟧ ⟹ Imagep R P y"

lemma r_r_into_tranclp: " r x y; r y z   r^++ x z"
by(rule tranclp.trancl_into_trancl)(rule tranclp.r_into_trancl)

(* The transitive closure of an already transitive relation is the relation itself. *)
lemma transp_tranclp_id:
  assumes "transp R"
  shows "tranclp R = R"
proof(intro ext iffI)
  (* Direction closure-to-R: induction over the closure, collapsing each
     step with transitivity of R. *)
  fix x y
  assume "R^++ x y"
  thus "R x y" by induction(blast dest: transpD[OF assms])+
  (* The converse direction is discharged by the trailing simp. *)
qed simp

lemma transp_inv_image: "transp r  transp (λx y. r (f x) (f y))"
using trans_inv_image[where r="{(x, y). r x y}" and f = f]
by(simp add: transp_trans inv_image_def)

lemma Domainp_conversep: "Domainp R¯¯ = Rangep R"
by(auto)

(* For a bi-unique relation R that relates the sets A and B via rel_set,
   there is a function f that is a bijection between A and B and sends
   elements of A to R-related elements.
   NOTE(review): the quantifier/connective symbols of the statement are not
   visible in this copy; the reading above follows the proof -- confirm. *)
lemma bi_unique_rel_set_bij_betw:
  assumes unique: "bi_unique R"
  and rel: "rel_set R A B"
  shows "f. bij_betw f A B  (xA. R x (f x))"
proof -
  (* Pick, via choice, a function f mapping each x in A to an R-related
     element (fact f) that lies in B (fact B). *)
  from assms obtain f where f: "x. x  A  R x (f x)" and B: "x. x  A  f x  B"
    apply(atomize_elim)
    apply(fold all_conj_distrib)
    apply(subst choice_iff[symmetric])
    apply(auto dest: rel_setD1)
    done
  (* Injectivity on A uses bi_uniqueDl of the bi-uniqueness assumption. *)
  have "inj_on f A" by(rule inj_onI)(auto dest!: f dest: bi_uniqueDl[OF unique])
  (* The image is exactly B; uses rel_setD2 and bi_uniqueDr. *)
  moreover have "f ` A = B" using rel
    by(auto 4 3 intro: B dest: rel_setD2 f bi_uniqueDr[OF unique])
  ultimately have "bij_betw f A B" unfolding bij_betw_def ..
  thus ?thesis using f by blast
qed

(* Restriction of a binary relation R to a domain predicate P and a range
   predicate Q: the restricted relation holds for x and y iff R x y, P x and
   Q y all hold. Written infix as  R ↿ P ⊗ Q  (priority 53, arguments 53/54/54). *)
definition restrict_relp :: "('a ⇒ 'b ⇒ bool) ⇒ ('a ⇒ bool) ⇒ ('b ⇒ bool) ⇒ 'a ⇒ 'b ⇒ bool"
  ("_ ↿ (_ ⊗ _)" [53, 54, 54] 53)
where "restrict_relp R P Q = (λx y. R x y ∧ P x ∧ Q y)"

lemma restrict_relp_apply [simp]: "(R  P  Q) x y  R x y  P x  Q y"
by(simp add: restrict_relp_def)

lemma restrict_relpI [intro?]: " R x y; P x; Q y   (R  P  Q) x y"
by(simp add: restrict_relp_def)

lemma restrict_relpE [elim?, cases pred]:
  assumes "(R  P  Q) x y"
  obtains (restrict_relp) "R x y" "P x" "Q y"
using assms by(simp add: restrict_relp_def)

lemma conversep_restrict_relp [simp]: "(R  P  Q)¯¯ = R¯¯  Q  P"
by(auto simp add: fun_eq_iff)

lemma restrict_relp_restrict_relp [simp]: "R  P  Q  P'  Q' = R  inf P P'  inf Q Q'"
by(auto simp add: fun_eq_iff)

lemma restrict_relp_cong:
  " P = P'; Q = Q'; x y.  P x; Q y   R x y = R' x y   R  P  Q = R'  P'  Q'"
by(auto simp add: fun_eq_iff)

lemma restrict_relp_cong_simp:
  " P = P'; Q = Q'; x y. P x =simp=> Q y =simp=> R x y = R' x y   R  P  Q = R'  P'  Q'"
by(rule restrict_relp_cong; simp add: simp_implies_def)

lemma restrict_relp_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((A ===> B ===> (=)) ===> (A ===> (=)) ===> (B ===> (=)) ===> A ===> B ===> (=)) restrict_relp restrict_relp"
unfolding restrict_relp_def[abs_def] by transfer_prover

lemma restrict_relp_mono: " R  R'; P  P'; Q  Q'   R  P  Q  R'  P'  Q'"
by(simp add: le_fun_def)

lemma restrict_relp_mono': 
  " (R  P  Q) x y;  R x y; P x; Q y   R' x y &&& P' x &&& Q' y 
   (R'  P'  Q') x y"
by(auto dest: conjunctionD1 conjunctionD2)

lemma restrict_relp_DomainpD: "Domainp (R  P  Q) x  Domainp R x  P x"
by(auto simp add: Domainp.simps)

lemma restrict_relp_True: "R  (λ_. True)  (λ_. True) = R"
by(simp add: fun_eq_iff)

lemma restrict_relp_False1: "R  (λ_. False)  Q = bot"
by(simp add: fun_eq_iff)

lemma restrict_relp_False2: "R  P  (λ_. False) = bot"
by(simp add: fun_eq_iff)

(* Relates a value a to a pair (c, b) by comparing a with the second
   component only; the first component c is ignored. *)
definition rel_prod2 :: "('a ⇒ 'b ⇒ bool) ⇒ 'a ⇒ ('c × 'b) ⇒ bool"
where "rel_prod2 R a = (λ(c, b). R a b)"

lemma rel_prod2_simps [simp]: "rel_prod2 R a (c, b)  R a b"
by(simp add: rel_prod2_def)

lemma restrict_rel_prod:
  "rel_prod (R  I1  I2) (S  I1'  I2') = rel_prod R S  pred_prod I1 I1'  pred_prod I2 I2'"
by(auto simp add: fun_eq_iff)

lemma restrict_rel_prod1:
  "rel_prod (R  I1  I2) S = rel_prod R S  pred_prod I1 (λ_. True)  pred_prod I2 (λ_. True)"
by(simp add: restrict_rel_prod[symmetric] restrict_relp_True)

lemma restrict_rel_prod2:
  "rel_prod R (S  I1  I2) = rel_prod R S  pred_prod (λ_. True) I1  pred_prod (λ_. True) I2"
by(simp add: restrict_rel_prod[symmetric] restrict_relp_True)

(* Witness function for relational composition: whenever (A OO B) relates the
   two components of the pair xy, relcompp_witness A B xy is an intermediate
   element that A relates to from fst xy and that B relates to snd xy.
   The constant is introduced by specification, i.e. only these two
   properties are known about it; the proof below shows that such a function
   exists (using choice). *)
consts relcompp_witness :: "('a ⇒ 'b ⇒ bool) ⇒ ('b ⇒ 'c ⇒ bool) ⇒ 'a × 'c ⇒ 'b"
specification (relcompp_witness)
  relcompp_witness1: "(A OO B) (fst xy) (snd xy) ⟹ A (fst xy) (relcompp_witness A B xy)"
  relcompp_witness2: "(A OO B) (fst xy) (snd xy) ⟹ B (relcompp_witness A B xy) (snd xy)"
  apply(fold all_conj_distrib)
  apply(rule choice allI)+
  by(auto intro: choice allI)

(* Bundle both specification facts under the single name relcompp_witness,
   instantiated at an explicit pair (x, y) and simplified. *)
lemmas relcompp_witness[of _ _ "(x, y)" for x y, simplified] = relcompp_witness1 relcompp_witness2

(* The raw specification facts remain reachable only by qualified name. *)
hide_fact (open) relcompp_witness1 relcompp_witness2

lemma relcompp_witness_eq [simp]: "relcompp_witness (=) (=) (x, x) = x"
  using relcompp_witness(1)[of "(=)" "(=)" x x] by(simp add: eq_OO)

subsection ‹Pairs›

lemma split_apfst [simp]: "case_prod h (apfst f xy) = case_prod (h  f) xy"
by(cases xy) simp

(* Builds a pair from a common seed s by applying f for the first and g for
   the second component. *)
definition corec_prod :: "('s ⇒ 'a) ⇒ ('s ⇒ 'b) ⇒ 's ⇒ 'a × 'b"
where "corec_prod f g = (λs. (f s, g s))"

lemma corec_prod_apply: "corec_prod f g s = (f s, g s)"
by(simp add: corec_prod_def)

lemma corec_prod_sel [simp]:
  shows fst_corec_prod: "fst (corec_prod f g s) = f s"
  and snd_corec_prod: "snd (corec_prod f g s) = g s"
by(simp_all add: corec_prod_apply)

lemma apfst_corec_prod [simp]: "apfst h (corec_prod f g s) = corec_prod (h  f) g s"
by(simp add: corec_prod_apply)

lemma apsnd_corec_prod [simp]: "apsnd h (corec_prod f g s) = corec_prod f (h  g) s"
by(simp add: corec_prod_apply)

lemma map_corec_prod [simp]: "map_prod f g (corec_prod h k s) = corec_prod (f  h) (g  k) s"
by(simp add: corec_prod_apply)

lemma split_corec_prod [simp]: "case_prod h (corec_prod f g s) = h (f s) (g s)"
by(simp add: corec_prod_apply)

lemma Pair_fst_Unity: "(fst x, ()) = x"
  by(cases x) simp

(* Reassociates a left-nested triple ((a, b), c) to the right: (a, (b, c)). *)
definition rprodl :: "('a × 'b) × 'c ⇒ 'a × ('b × 'c)" where "rprodl = (λ((a, b), c). (a, (b, c)))"

lemma rprodl_simps [simp]: "rprodl ((a, b), c) = (a, (b, c))"
  by(simp add: rprodl_def)

lemma rprodl_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_prod (rel_prod A B) C ===> rel_prod A (rel_prod B C)) rprodl rprodl"
  unfolding rprodl_def by transfer_prover

(* Reassociates a right-nested triple (a, (b, c)) to the left: ((a, b), c). *)
definition lprodr :: "'a × ('b × 'c) ⇒ ('a × 'b) × 'c" where "lprodr = (λ(a, b, c). ((a, b), c))"

lemma lprodr_simps [simp]: "lprodr (a, b, c) = ((a, b), c)"
  by(simp add: lprodr_def)

lemma lprodr_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_prod A (rel_prod B C) ===> rel_prod (rel_prod A B) C) lprodr lprodr"
  unfolding lprodr_def by transfer_prover

lemma lprodr_inverse [simp]: "rprodl (lprodr x) = x"
  by(cases x) auto

lemma rprodl_inverse [simp]: "lprodr (rprodl x) = x"
  by(cases x) auto

lemma pred_prod_mono' [mono]:
  "pred_prod A B xy  pred_prod A' B' xy"
  if "x. A x  A' x" "y. B y  B' y"
  using that by(cases xy) auto

(* Regroups a pair of pairs componentwise: first components together,
   second components together. *)
fun rel_witness_prod :: "('a × 'b) × ('c × 'd) ⇒ (('a × 'c) × ('b × 'd))" where
  "rel_witness_prod ((a, b), (c, d)) = ((a, c), (b, d))"

subsection ‹Sums›

lemma islE:
  assumes "isl x"
  obtains l where "x = Inl l"
using assms by(cases x) auto

lemma Inl_in_Plus [simp]: "Inl x  A <+> B  x  A"
by auto

lemma Inr_in_Plus [simp]: "Inr x  A <+> B  x  B"
by auto

lemma Inl_eq_map_sum_iff: "Inl x = map_sum f g y  (z. y = Inl z  x = f z)"
by(cases y) auto

lemma Inr_eq_map_sum_iff: "Inr x = map_sum f g y  (z. y = Inr z  x = g z)"
by(cases y) auto

(* map_sum f g is injective on the disjoint sum A <+> B when f is injective
   on A and g is injective on B. *)
lemma inj_on_map_sum [simp]:
  " inj_on f A; inj_on g B   inj_on (map_sum f g) (A <+> B)"
proof(rule inj_onI, goal_cases)
  case (1 x y)
  (* Case split on the injections (Inl/Inr) of both arguments. *)
  then show ?case by(cases x; cases y; auto simp add: inj_on_def)
qed

lemma inv_into_map_sum:
  "inv_into (A <+> B) (map_sum f g) x = map_sum (inv_into A f) (inv_into B g) x"
  if "x  f ` A <+> g ` B" "inj_on f A" "inj_on g B"
  using that by(cases rule: PlusE[consumes 1])(auto simp add: inv_into_f_eq f_inv_into_f)

(* Reassociates a left-nested sum ('a + 'b) + 'c to the right-nested form
   'a + ('b + 'c), keeping the injected value. *)
fun rsuml :: "('a + 'b) + 'c ⇒ 'a + ('b + 'c)" where
  "rsuml (Inl (Inl a)) = Inl a"
| "rsuml (Inl (Inr b)) = Inr (Inl b)"
| "rsuml (Inr c) = Inr (Inr c)"

(* Reassociates a right-nested sum 'a + ('b + 'c) to the left-nested form
   ('a + 'b) + 'c, keeping the injected value. *)
fun lsumr :: "'a + ('b + 'c) ⇒ ('a + 'b) + 'c" where
  "lsumr (Inl a) = Inl (Inl a)"
| "lsumr (Inr (Inl b)) = Inl (Inr b)"
| "lsumr (Inr (Inr c)) = Inr c"

lemma rsuml_lsumr [simp]: "rsuml (lsumr x) = x"
  by(cases x rule: lsumr.cases) simp_all

lemma lsumr_rsuml [simp]: "lsumr (rsuml x) = x"
  by(cases x rule: rsuml.cases) simp_all

subsection ‹Option›

declare is_none_bind [simp]

lemma case_option_collapse: "case_option x (λ_. x) y = x"
by(simp split: option.split)

lemma indicator_single_Some: "indicator {Some x} (Some y) = indicator {x} y"
by(simp split: split_indicator)

subsubsection ‹Predicator and relator›

lemma option_pred_mono_strong:
  " pred_option P x; a.  a  set_option x; P a   P' a   pred_option P' x"
by(fact option.pred_mono_strong)

lemma option_pred_map [simp]: "pred_option P (map_option f x) = pred_option (P  f) x"
by(fact option.pred_map)

lemma option_pred_o_map [simp]: "pred_option P  map_option f = pred_option (P  f)"
by(simp add: fun_eq_iff)

lemma option_pred_bind [simp]: "pred_option P (Option.bind x f) = pred_option (pred_option P  f) x"
by(simp add: pred_option_def)

lemma pred_option_conj [simp]:
  "pred_option (λx. P x  Q x) = (λx. pred_option P x  pred_option Q x)"
by(auto simp add: pred_option_def)

lemma pred_option_top [simp]:
  "pred_option (λ_. True) = (λ_. True)"
by(fact option.pred_True)

lemma rel_option_restrict_relpI [intro?]:
  " rel_option R x y; pred_option P x; pred_option Q y   rel_option (R  P  Q) x y"
by(erule option.rel_mono_strong) simp

(* Elimination rule: from rel_option over a restricted relation one obtains
   rel_option over the unrestricted relation plus pred_option for both
   restricting predicates. *)
lemma rel_option_restrict_relpE [elim?]:
  assumes "rel_option (R  P  Q) x y"
  obtains "rel_option R x y" "pred_option P x" "pred_option Q y"
proof
  (* Weaken the relation pointwise. *)
  show "rel_option R x y" using assms by(auto elim!: option.rel_mono_strong)
  (* x lies in the domain of the lifted relation, so pred_option of the
     domain predicate holds; restrict_relp_DomainpD then yields P. *)
  have "pred_option (Domainp (R  P  Q)) x" using assms by(fold option.Domainp_rel) blast
  then show "pred_option P x" by(rule option_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)
  (* Same argument for y, applied to the converse relation. *)
  have "pred_option (Domainp (R  P  Q)¯¯) y" using assms
    by(fold option.Domainp_rel)(auto simp only: option.rel_conversep Domainp_conversep)
  then show "pred_option Q y" by(rule option_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

lemma rel_option_restrict_relp_iff:
  "rel_option (R  P  Q) x y  rel_option R x y  pred_option P x  pred_option Q y"
by(blast intro: rel_option_restrict_relpI elim: rel_option_restrict_relpE)

lemma option_rel_map_restrict_relp:
  shows option_rel_map_restrict_relp1:
  "rel_option (R  P  Q) (map_option f x) = rel_option (R  f  P  f  Q) x"
  and option_rel_map_restrict_relp2:
  "rel_option (R  P  Q) x (map_option g y) = rel_option ((λx. R x  g)  P  Q  g) x y"
by(simp_all add: option.rel_map restrict_relp_def fun_eq_iff)

(* Combines a pair of options into an option of pairs: two Some values are
   paired up, every other combination yields None. *)
fun rel_witness_option :: "'a option × 'b option ⇒ ('a × 'b) option" where
  "rel_witness_option (Some x, Some y) = Some (x, y)"
| "rel_witness_option (None, None) = None"
| "rel_witness_option _ = None" ― ‹Just to make the definition complete›

(* Characteristic properties of rel_witness_option for options x, y related
   by rel_option A: the pair stored in the witness is A-related, and mapping
   fst / snd over the witness gives back x / y. *)
lemma rel_witness_option:
  shows set_rel_witness_option: " rel_option A x y; (a, b)  set_option (rel_witness_option (x, y))   A a b"
    and map1_rel_witness_option: "rel_option A x y  map_option fst (rel_witness_option (x, y)) = x"
    and map2_rel_witness_option: "rel_option A x y  map_option snd (rel_witness_option (x, y)) = y"
  (* Each goal: split along the defining clauses of rel_witness_option and
     require simp to close every case (fail otherwise). *)
  by(cases "(x, y)" rule: rel_witness_option.cases; simp; fail)+

lemma rel_witness_option1:
  assumes "rel_option A x y"
  shows "rel_option (λa (a', b). a = a'  A a' b) x (rel_witness_option (x, y))"
  using map1_rel_witness_option[OF assms, symmetric]
  unfolding option.rel_eq[symmetric] option.rel_map
  by(rule option.rel_mono_strong)(auto intro: set_rel_witness_option[OF assms])

lemma rel_witness_option2:
  assumes "rel_option A x y"
  shows "rel_option (λ(a, b') b. b = b'  A a b') (rel_witness_option (x, y)) y"
  using map2_rel_witness_option[OF assms]
  unfolding option.rel_eq[symmetric] option.rel_map
  by(rule option.rel_mono_strong)(auto intro: set_rel_witness_option[OF assms])

subsubsection ‹Orders on option›

(* Shorthand for ord_option instantiated with equality on the elements. *)
abbreviation le_option :: "'a option ⇒ 'a option ⇒ bool"
where "le_option ≡ ord_option (=)"

lemma le_option_bind_mono:
  " le_option x y; a. a  set_option x  le_option (f a) (g a) 
   le_option (Option.bind x f) (Option.bind y g)"
by(cases x) simp_all

lemma le_option_refl [simp]: "le_option x x"
by(cases x) simp_all


lemma le_option_conv_option_ord: "le_option = option_ord"
by(auto simp add: fun_eq_iff flat_ord_def elim: ord_option.cases)

(* Relates x to an option y iff y is Some z for a z that R relates x to;
   in particular nothing is related to None. *)
definition pcr_Some :: "('a ⇒ 'b ⇒ bool) ⇒ 'a ⇒ 'b option ⇒ bool"
where "pcr_Some R x y ⟷ (∃z. y = Some z ∧ R x z)"

lemma pcr_Some_simps [simp]: "pcr_Some R x (Some y)  R x y"
by(simp add: pcr_Some_def)

lemma pcr_SomeE [cases pred]:
  assumes "pcr_Some R x y"
  obtains (pcr_Some) z where "y = Some z" "R x z"
using assms by(auto simp add: pcr_Some_def)

subsubsection ‹Filter for option›

(* Keeps the content of an option only if it satisfies P; None stays None
   and Some x with ¬ P x becomes None. *)
fun filter_option :: "('a ⇒ bool) ⇒ 'a option ⇒ 'a option"
where
  "filter_option P None = None"
| "filter_option P (Some x) = (if P x then Some x else None)"

lemma set_filter_option [simp]: "set_option (filter_option P x) = {y  set_option x. P y}"
by(cases x) auto

lemma filter_map_option: "filter_option P (map_option f x) = map_option f (filter_option (P  f) x)"
by(cases x) simp_all

lemma is_none_filter_option [simp]: "Option.is_none (filter_option P x)  Option.is_none x  ¬ P (the x)"
by(cases x) simp_all

lemma filter_option_eq_Some_iff [simp]: "filter_option P x = Some y  x = Some y  P y"
by(cases x) auto

lemma Some_eq_filter_option_iff [simp]: "Some y = filter_option P x  x = Some y  P y"
by(cases x) auto

lemma filter_conv_bind_option: "filter_option P x = Option.bind x (λy. if P y then Some y else None)"
by(cases x) simp_all

subsubsection ‹Assert for option›

(* Turns a Boolean into a unit option: True gives Some (), False gives None. *)
primrec assert_option :: "bool ⇒ unit option" where
  "assert_option True = Some ()"
| "assert_option False = None"

lemma set_assert_option_conv: "set_option (assert_option b) = (if b then {()} else {})"
by(simp)

lemma in_set_assert_option [simp]: "x  set_option (assert_option b)  b"
by(cases b) simp_all


subsubsection ‹Join on options›

(* Flattens a nested option: Some y is replaced by y, None stays None. *)
definition join_option :: "'a option option ⇒ 'a option"
where "join_option x = (case x of Some y ⇒ y | None ⇒ None)"

simps_of_case join_simps [simp, code]: join_option_def

lemma set_join_option [simp]: "set_option (join_option x) = (set_option ` set_option x)"
by(cases x)(simp_all)

lemma in_set_join_option: "x  set_option (join_option (Some (Some x)))"
by simp

lemma map_join_option: "map_option f (join_option x) = join_option (map_option (map_option f) x)"
by(cases x) simp_all

lemma bind_conv_join_option: "Option.bind x f = join_option (map_option f x)"
by(cases x) simp_all

lemma join_conv_bind_option: "join_option x = Option.bind x id"
by(cases x) simp_all

lemma join_option_parametric [transfer_rule]:
  includes lifting_syntax shows
  "(rel_option (rel_option R) ===> rel_option R) join_option join_option"
unfolding join_conv_bind_option[abs_def] by transfer_prover

lemma join_option_eq_Some [simp]: "join_option x = Some y  x = Some (Some y)"
by(cases x) simp_all

lemma Some_eq_join_option [simp]: "Some y = join_option x  x = Some (Some y)"
by(cases x) auto

lemma join_option_eq_None: "join_option x = None  x = None  x = Some None"
by(cases x) simp_all

lemma None_eq_join_option: "None = join_option x  x = None  x = Some None"
by(cases x) auto

subsubsection ‹Zip on options›

(* Pairs up the contents of two options; the result is None as soon as one
   argument is None. The last two clauses overlap on (None, None) but agree
   there, which is why the definition goes through the function package with
   an explicit completeness/compatibility proof and termination proof. *)
function zip_option :: "'a option ⇒ 'b option ⇒ ('a × 'b) option"
where
  "zip_option (Some x) (Some y) = Some (x, y)"
| "zip_option _ None = None"
| "zip_option None _ = None"
by pat_completeness auto
termination by lexicographic_order

lemma zip_option_eq_Some_iff [iff]:
  "zip_option x y = Some (a, b)  x = Some a  y = Some b"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma set_zip_option [simp]:
  "set_option (zip_option x y) = set_option x × set_option y"
by auto

lemma zip_map_option1: "zip_option (map_option f x) y = map_option (apfst f) (zip_option x y)"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma zip_map_option2: "zip_option x (map_option g y) = map_option (apsnd g) (zip_option x y)"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma map_zip_option:
  "map_option (map_prod f g) (zip_option x y) = zip_option (map_option f x) (map_option g y)"
by(simp add: zip_map_option1 zip_map_option2 option.map_comp apfst_def apsnd_def o_def prod.map_comp)

lemma zip_conv_bind_option:
  "zip_option x y = Option.bind x (λx. Option.bind y (λy. Some (x, y)))"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma zip_option_parametric [transfer_rule]:
  includes lifting_syntax shows
  "(rel_option R ===> rel_option Q ===> rel_option (rel_prod R Q)) zip_option zip_option"
unfolding zip_conv_bind_option[abs_def] by transfer_prover

lemma rel_option_eqI [simp]: "rel_option (=) x x"
by(simp add: option.rel_eq)

subsubsection ‹Binary supremum on @{typ "'a option"}›

(* Right-biased combination of two options: the second argument wins when
   it is Some, otherwise the first argument is returned unchanged. *)
primrec sup_option :: "'a option ⇒ 'a option ⇒ 'a option"
where
  "sup_option x None = x"
| "sup_option x (Some y) = (Some y)"

lemma sup_option_idem [simp]: "sup_option x x = x"
by(cases x) simp_all

lemma sup_option_assoc: "sup_option (sup_option x y) z = sup_option x (sup_option y z)"
by(cases z) simp_all

lemma sup_option_left_idem: "sup_option x (sup_option x y) = sup_option x y"
by(rewrite sup_option_assoc[symmetric])(simp)

lemmas sup_option_ai = sup_option_assoc sup_option_left_idem

lemma sup_option_None [simp]: "sup_option None y = y"
by(cases y) simp_all

subsubsection ‹Restriction on @{typ "'a option"}›

(* Restricts an option to contents satisfying P: Some x is kept when P x
   holds and replaced by None otherwise; None stays None.
   The (transfer) option asks primrec to also derive a transfer rule for
   the new constant. *)
primrec (transfer) enforce_option :: "('a ⇒ bool) ⇒ 'a option ⇒ 'a option" where
  "enforce_option P (Some x) = (if P x then Some x else None)"
| "enforce_option P None = None"

lemma set_enforce_option [simp]: "set_option (enforce_option P x) = {a  set_option x. P a}"
  by(cases x) auto

lemma enforce_map_option: "enforce_option P (map_option f x) = map_option f (enforce_option (P  f) x)"
  by(cases x) auto

lemma enforce_bind_option [simp]:
  "enforce_option P (Option.bind x f) = Option.bind x (enforce_option P  f)"
  by(cases x) auto

lemma enforce_option_alt_def:
  "enforce_option P x = Option.bind x (λa. Option.bind (assert_option (P a)) (λ_ :: unit. Some a))"
  by(cases x) simp_all

lemma enforce_option_eq_None_iff [simp]:
  "enforce_option P x = None  (a. x = Some a  ¬ P a)"
  by(cases x) auto

lemma enforce_option_eq_Some_iff [simp]:
  "enforce_option P x = Some y  x = Some y  P y"
  by(cases x) auto

lemma Some_eq_enforce_option_iff [simp]:
  "Some y = enforce_option P x  x = Some y  P y"
  by(cases x) auto

lemma enforce_option_top [simp]: "enforce_option  = id"
  by(rule ext; rename_tac x; case_tac x; simp)

lemma enforce_option_K_True [simp]: "enforce_option (λ_. True) x = x"
  by(cases x) simp_all

lemma enforce_option_bot [simp]: "enforce_option  = (λ_. None)"
  by(simp add: fun_eq_iff)

lemma enforce_option_K_False [simp]: "enforce_option (λ_. False) x = None"
  by simp

lemma enforce_pred_id_option: "pred_option P x  enforce_option P x = x"
  by(cases x) auto

subsubsection ‹Maps›

lemma map_add_apply: "(m1 ++ m2) x = sup_option (m1 x) (m2 x)"
by(simp add: map_add_def split: option.split)

lemma map_le_map_upd2: " f m g; y'. f x = Some y'  y' = y   f m g(x  y)"
by(cases "x  dom f")(auto simp add: map_le_def Ball_def)

lemma eq_None_iff_not_dom: "f x = None  x  dom f"
by auto

lemma card_ran_le_dom: "finite (dom m)  card (ran m)  card (dom m)"
by(simp add: ran_alt_def card_image_le)

(* For a map with finite range, the domain being included in the range is
   equivalent to domain and range being equal.
   NOTE(review): the relation symbols between dom m and ran m are not visible
   in this copy; the reading follows the use of card_mono / card_subset_eq
   below -- confirm. *)
lemma dom_subset_ran_iff:
  assumes "finite (ran m)"
  shows "dom m  ran m  dom m = ran m"
proof
  assume le: "dom m  ran m"
  (* Cardinality comparison in both directions ... *)
  then have "card (dom m)  card (ran m)" by(simp add: card_mono assms)
  moreover have "card (ran m)  card (dom m)" by(simp add: finite_subset[OF le assms] card_ran_le_dom)
  (* ... which card_subset_eq turns into equality of the two finite sets. *)
  ultimately show "dom m = ran m" using card_subset_eq[OF assms le] by simp
  (* The converse direction is closed by the trailing simp. *)
qed simp

text ‹
  We need a polymorphic constant for the empty map such that ‹transfer_prover›
  can use a custom transfer rule for @{const Map.empty}.
›
(* Separate constant standing for Map.empty; the defining equation is a
   simp rule, so Map_empty is unfolded to Map.empty by simp. *)
definition Map_empty where [simp]: "Map_empty ≡ Map.empty"

lemma map_le_Some1D: " m m m'; m x = Some y   m' x = Some y"
by(auto simp add: map_le_def Ball_def)

lemma map_le_fun_upd2: " f m g; x  dom f   f m g(x := y)"
by(auto simp add: map_le_def)

lemma map_eqI: "xdom m  dom m'. m x = m' x  m = m'"
by(auto simp add: fun_eq_iff domIff intro: option.expand)


subsection ‹Countable›

(* The least fixpoint of a sup-continuous set functional F is countable
   provided F maps countable sets to countable sets. *)
lemma countable_lfp:
  assumes step: "Y. countable Y  countable (F Y)"
  and cont: "Order_Continuity.sup_continuous F"
  shows "countable (lfp F)"
(* Rewrite lfp F with sup_continuous_lfp (available by continuity), then use
   countability of all finite iterates of F (countable_funpow). *)
by(subst sup_continuous_lfp[OF cont])(simp add: countable_funpow[OF step])

(* Pointwise variant of countable_lfp for functionals on set-valued
   functions: lfp F x is countable for every argument x. *)
lemma countable_lfp_apply:
  assumes step: "Y x. (x. countable (Y x))  countable (F Y x)"
  and cont: "Order_Continuity.sup_continuous F"
  shows "countable (lfp F x)"
proof -
  (* Every finite iterate of F starting from bot is pointwise countable,
     by induction on the number of iterations. *)
  { fix n
    have "x. countable ((F ^^ n) bot x)"
      by(induct n)(auto intro: step) }
  (* Conclude via the sup_continuous_lfp characterisation of lfp F. *)
  thus ?thesis using cont by(simp add: sup_continuous_lfp)
qed


subsection ‹ Extended naturals ›

lemma idiff_enat_eq_enat_iff: "x - enat n = enat m  (k. x = enat k  k - n = m)"
  by (cases x) simp_all

lemma eSuc_SUP: "A  {}  eSuc ( (f ` A)) = (xA. eSuc (f x))"
  by (subst eSuc_Sup) (simp_all add: image_comp)

lemma ereal_of_enat_1: "ereal_of_enat 1 = ereal 1"
  by (simp add: one_enat_def)

lemma ennreal_real_conv_ennreal_of_enat: "ennreal (real n) = ennreal_of_enat n"
  by (simp add: ennreal_of_nat_eq_real_of_nat)

lemma enat_add_sub_same2: "b    a + b - b = (a :: enat)"
  by (cases a; cases b) simp_all

lemma enat_sub_add: "y  x  x - y + z = x + z - (y :: enat)"
  by (cases x; cases y; cases z) simp_all

lemma SUP_enat_eq_0_iff [simp]: " (f ` A) = (0 :: enat)  (xA. f x = 0)"
  by (simp add: bot_enat_def [symmetric])

(* Adding a constant c on the right commutes with a supremum over an index
   set I in enat; I is assumed to satisfy the side condition stated in the
   assumption (presumably non-emptiness -- the relation symbol is not visible
   in this copy; confirm). *)
lemma SUP_enat_add_left:
  assumes "I  {}"
  shows "(SUP iI. f i + c :: enat) = (SUP iI. f i) + c" (is "?lhs = ?rhs")
(* Case distinction on c; in the finite case the equation is shown by
   antisymmetry. *)
proof(cases "c", rule antisym)
  case (enat n)
  show "?lhs  ?rhs" by(auto 4 3 intro: SUP_upper intro: SUP_least)
  (* Other direction: bound the supremum by ?lhs - c, add c on both sides
     (add_right_mono) and cancel using enat_sub_add / enat_add_sub_same2. *)
  have "(SUP iI. f i)  ?lhs - c" using enat 
    by(auto simp add: enat_add_sub_same2 intro!: SUP_least order_trans[OF _ SUP_upper[THEN enat_minus_mono1]])
  note add_right_mono[OF this, of c]
  also have " + c  ?lhs" using assms
    by(subst enat_sub_add)(auto intro: SUP_upper2 simp add: enat_add_sub_same2 enat)
  finally show "?rhs  ?lhs" .
  (* The remaining case of c is closed by simp with SUP_constant. *)
qed(simp add: assms SUP_constant)

lemma SUP_enat_add_right:
  assumes "I  {}"
  shows "(SUP iI. c + f i :: enat) = c + (SUP iI. f i)"
using SUP_enat_add_left[OF assms, of f c]
by(simp add: add.commute)

lemma iadd_SUP_le_iff: "n + (SUP xA. f x :: enat)  y  (if A = {} then n  y else xA. n + f x  y)"
by(simp add: bot_enat_def SUP_enat_add_right[symmetric] SUP_le_iff)

lemma SUP_iadd_le_iff: "(SUP xA. f x :: enat) + n  y  (if A = {} then n  y else xA. f x + n  y)"
using iadd_SUP_le_iff[of n f A y] by(simp add: add.commute)


subsection ‹Extended non-negative reals›

(* In a finite measure, the non-negative integral of the indicator of A
   composed with f differs from the bound on the right-hand side, given that
   the preimage of A under f is measurable.
   NOTE(review): the right-hand side is not visible in this copy; judging by
   the lemma name it is presumably infinity -- confirm. *)
lemma (in finite_measure) nn_integral_indicator_neq_infty: 
  "f -` A  sets M  (+ x. indicator A (f x) M)  "
unfolding ennreal_indicator[symmetric]
(* Reduce to integrability of the indicator, which is bounded by the
   constant 1. *)
apply(rule integrableD)
apply(rule integrable_const_bound[where B=1])
apply(simp_all add: indicator_vimage[symmetric])
done

(* Variant of nn_integral_indicator_neq_infty, derived from it by simp.
   NOTE(review): the right-hand side is not visible in this copy; judging by
   the lemma name it is presumably the top element -- confirm. *)
lemma (in finite_measure) nn_integral_indicator_neq_top: 
  "f -` A  sets M  (+ x. indicator A (f x) M)  "
by(drule nn_integral_indicator_neq_infty) simp

lemma nn_integral_indicator_map:
  assumes [measurable]: "f  measurable M N" "{xspace N. P x}  sets N"
  shows "(+x. indicator {xspace N. P x} (f x) M) = emeasure M {xspace M. P (f x)}"
  using assms(1)[THEN measurable_space] 
  by (subst nn_integral_indicator[symmetric])
     (auto intro!: nn_integral_cong split: split_indicator simp del: nn_integral_indicator)


subsection ‹BNF material›

lemma transp_rel_fun: " is_equality Q; transp R   transp (rel_fun Q R)"
by(rule transpI)(auto dest: transpD rel_funD simp add: is_equality_def)

lemma rel_fun_inf: "inf (rel_fun Q R) (rel_fun Q R') = rel_fun Q (inf R R')"
by(rule antisym)(auto elim: rel_fun_mono dest: rel_funD)

lemma reflp_fun1: includes lifting_syntax shows " is_equality A; reflp B   reflp (A ===> B)"
by(simp add: reflp_def rel_fun_def is_equality_def)

lemma type_copy_id': "type_definition (λx. x) (λx. x) UNIV"
by unfold_locales simp_all

lemma type_copy_id: "type_definition id id UNIV"
by(simp add: id_def type_copy_id')

lemma GrpE [cases pred]:
  assumes "BNF_Def.Grp A f x y"
  obtains (Grp) "y = f x" "x  A"
using assms
by(simp add: Grp_def)

(* For a type definition (Rep, Abs, A): the function relator between the
   graph of Abs on A and the graph of g on B is itself the graph of the
   function-space map Rep ---> g, restricted to a set of functions
   constrained by A and B (see the set comprehension in the statement). *)
lemma rel_fun_Grp_copy_Abs:
  includes lifting_syntax
  assumes "type_definition Rep Abs A"
  shows "rel_fun (BNF_Def.Grp A Abs) (BNF_Def.Grp B g) = BNF_Def.Grp {f. f ` A  B} (Rep ---> g)"
proof -
  (* Interpreting the locale makes Abs_inverse, Rep_inverse and Rep
     available to the automation below. *)
  interpret type_definition Rep Abs A by fact
  show ?thesis
    by(auto simp add: rel_fun_def Grp_def fun_eq_iff Abs_inverse Rep_inverse intro!: Rep)
qed

lemma rel_set_Grp:
  "rel_set (BNF_Def.Grp A f) = BNF_Def.Grp {B. B  A} (image f)"
by(auto simp add: rel_set_def BNF_Def.Grp_def fun_eq_iff)

(* Decomposes rel_set R as a relational composition: the converse of the
   graph of image-under-fst, followed by the graph of image-under-snd, both
   taken over sets of pairs drawn from the graph of R. *)
lemma rel_set_comp_Grp:
  "rel_set R = (BNF_Def.Grp {x. x  {(x, y). R x y}} ((`) fst))¯¯ OO BNF_Def.Grp {x. x  {(x, y). R x y}} ((`) snd)"
apply(auto 4 4 del: ext intro!: ext simp add: BNF_Def.Grp_def intro!: rel_setI intro: rev_bexI)
apply(simp add: relcompp_apply)
subgoal for A B
  (* Witness for the intermediate set of pairs: those pairs built from A and
     B that lie in R. *)
  apply(rule exI[where x="A × B  {(x, y). R x y}"])
  apply(auto 4 3 dest: rel_setD1 rel_setD2 intro: rev_image_eqI)
  done
done

lemma Domainp_Grp: "Domainp (BNF_Def.Grp A f) = (λx. x  A)"
by(auto simp add: fun_eq_iff Grp_def)

lemma pred_prod_conj [simp]:
  shows pred_prod_conj1: "P Q R. pred_prod (λx. P x  Q x) R = (λx. pred_prod P R x  pred_prod Q R x)"
  and pred_prod_conj2: "P Q R. pred_prod P (λx. Q x  R x) = (λx. pred_prod P Q x  pred_prod P R x)"
by(auto simp add: pred_prod.simps)

lemma pred_sum_conj [simp]:
  shows pred_sum_conj1: "P Q R. pred_sum (λx. P x  Q x) R = (λx. pred_sum P R x  pred_sum Q R x)"
  and pred_sum_conj2: "P Q R. pred_sum P (λx. Q x  R x) = (λx. pred_sum P Q x  pred_sum P R x)"
by(auto simp add: pred_sum.simps fun_eq_iff)

lemma pred_list_conj [simp]: "list_all (λx. P x  Q x) = (λx. list_all P x  list_all Q x)"
by(auto simp add: list_all_def)

lemma pred_prod_top [simp]:
  "pred_prod (λ_. True) (λ_. True) = (λ_. True)"
by(simp add: pred_prod.simps fun_eq_iff)

lemma rel_fun_conversep: includes lifting_syntax shows
  "(A^--1 ===> B^--1) = (A ===> B)^--1"
by(auto simp add: rel_fun_def fun_eq_iff)

lemma left_unique_Grp [iff]:
  "left_unique (BNF_Def.Grp A f)  inj_on f A"
unfolding Grp_def left_unique_def by(auto simp add: inj_on_def)

lemma right_unique_Grp [simp, intro!]: "right_unique (BNF_Def.Grp A f)"
by(simp add: Grp_def right_unique_def)

lemma bi_unique_Grp [iff]:
  "bi_unique (BNF_Def.Grp A f)  inj_on f A"
by(simp add: bi_unique_alt_def)

lemma left_total_Grp [iff]:
  "left_total (BNF_Def.Grp A f)  A = UNIV"
by(auto simp add: left_total_def Grp_def)

lemma right_total_Grp [iff]:
  "right_total (BNF_Def.Grp A f)  f ` A = UNIV"
by(auto simp add: right_total_def BNF_Def.Grp_def image_def)

lemma bi_total_Grp [iff]:
  "bi_total (BNF_Def.Grp A f)  A = UNIV  surj f"
by(auto simp add: bi_total_alt_def)

lemma left_unique_vimage2p [simp]:
  " left_unique P; inj f   left_unique (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro left_unique_OO) simp_all

lemma right_unique_vimage2p [simp]:
  " right_unique P; inj g   right_unique (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro right_unique_OO) simp_all

lemma bi_unique_vimage2p [simp]:
  " bi_unique P; inj f; inj g   bi_unique (BNF_Def.vimage2p f g P)"
unfolding bi_unique_alt_def by simp

lemma left_total_vimage2p [simp]:
  " left_total P; surj g   left_total (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro left_total_OO) simp_all

lemma right_total_vimage2p [simp]:
  " right_total P; surj f   right_total (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro right_total_OO) simp_all

lemma bi_total_vimage2p [simp]:
  " bi_total P; surj f; surj g   bi_total (BNF_Def.vimage2p f g P)"
unfolding bi_total_alt_def by simp

lemma vimage2p_eq [simp]:
  "inj f  BNF_Def.vimage2p f f (=) = (=)"
by(auto simp add: vimage2p_def fun_eq_iff inj_on_def)

lemma vimage2p_conversep: "BNF_Def.vimage2p f g R^--1 = (BNF_Def.vimage2p g f R)^--1"
by(simp add: vimage2p_def fun_eq_iff)

lemma rel_fun_refl: " A  (=); (=)  B   (=)  rel_fun A B"
  by(subst fun.rel_eq[symmetric])(rule fun_mono)

lemma rel_fun_mono_strong:
  " rel_fun A B f g; A'  A; x y.  x  f ` {x. Domainp A' x}; y  g ` {x. Rangep A' x}; B x y   B' x y   rel_fun A' B' f g"
  by(auto simp add: rel_fun_def) fastforce

lemma rel_fun_refl_strong: 
  assumes "A  (=)" "x. x  f ` {x. Domainp A x}  B x x"
  shows "rel_fun A B f f"
proof -
  have "rel_fun (=) (=) f f" by(simp add: rel_fun_eq)
  then show ?thesis using assms(1)
    by(rule rel_fun_mono_strong) (auto intro: assms(2))
qed

lemma Grp_iff: "BNF_Def.Grp B g x y  y = g x  x  B" by(simp add: Grp_def)

lemma Rangep_Grp: "Rangep (BNF_Def.Grp A f) = (λx. x  f ` A)"
  by(auto simp add: fun_eq_iff Grp_iff)

lemma rel_fun_Grp:
  "rel_fun (BNF_Def.Grp UNIV h)¯¯ (BNF_Def.Grp A g) = BNF_Def.Grp {f. f ` range h  A} (map_fun h g)"
  by(auto simp add: rel_fun_def fun_eq_iff Grp_iff)

subsection ‹Transfer and lifting material›

(* Parametricity rules stated with the ===> notation from lifting_syntax,
   which is only active inside this context block. *)
context includes lifting_syntax begin

(* monotone is parametric when the relation A on the source type is bi-total;
   the assumption is made available to transfer_prover as a local rule. *)
lemma monotone_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total A"
  shows "((A ===> A ===> (=)) ===> (B ===> B ===> (=)) ===> (A ===> B) ===> (=)) monotone monotone"
unfolding monotone_def[abs_def] by transfer_prover

(* fun_ord is parametric when the relation C on the argument type is bi-total. *)
lemma fun_ord_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total C"
  shows "((A ===> B ===> (=)) ===> (C ===> A) ===> (C ===> B) ===> (=)) fun_ord fun_ord"
unfolding fun_ord_def[abs_def] by transfer_prover

(* The disjoint sum of sets respects rel_set / rel_sum. *)
lemma Plus_parametric [transfer_rule]:
  "(rel_set A ===> rel_set B ===> rel_set (rel_sum A B)) (<+>) (<+>)"
unfolding Plus_def[abs_def] by transfer_prover

(* pred_fun is parametric when the relation A on the argument type is bi-total. *)
lemma pred_fun_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total A"
  shows "((A ===> (=)) ===> (B ===> (=)) ===> (A ===> B) ===> (=)) pred_fun pred_fun"
unfolding pred_fun_def by(transfer_prover)

(* With equality on the argument side, the function relator distributes over
   relational composition of the result relations. *)
lemma rel_fun_eq_OO: "((=) ===> A) OO ((=) ===> B) = ((=) ===> A OO B)"
by(clarsimp simp add: rel_fun_def fun_eq_iff relcompp.simps) metis

end

(* Transfer rule for set equality under a quotient: if sets A, B and C, D are
   related by rel_set T, then rel_set R A C holds iff B = D. *)
lemma Quotient_set_rel_eq:
  includes lifting_syntax
  assumes "Quotient R Abs Rep T"
  shows "(rel_set T ===> rel_set T ===> (=)) (rel_set R) (=)"
proof(rule rel_funI iffI)+
  fix A B C D
  assume AB: "rel_set T A B" and CD: "rel_set T C D"
  (* Two consequences of the Quotient assumption (via Quotient_alt_def):
     a characterisation of R in terms of T and Abs, and the fact that
     T-related elements are linked by Abs. *)
  have *: "x y. R x y = (T x (Abs x)  T y (Abs y)  Abs x = Abs y)"
    "a b. T a b  Abs a = b"
    using assms unfolding Quotient_alt_def by simp_all

  (* Direction 1: equal abstract sets give rel_set R on the concrete sets. *)
  { assume [simp]: "B = D"
    thus "rel_set R A C"
      by(auto 4 4 intro!: rel_setI dest: rel_setD1[OF AB, simplified] rel_setD2[OF AB, simplified] rel_setD2[OF CD] rel_setD1[OF CD] simp add: * elim!: rev_bexI)
  next
    (* Direction 2: rel_set R on the concrete sets gives equality of the
       abstract sets; each inclusion chains three rel_set projections and
       finishes with the facts in *. *)
    assume AC: "rel_set R A C"
    show "B = D"
      apply safe
       apply(drule rel_setD2[OF AB], erule bexE)
       apply(drule rel_setD1[OF AC], erule bexE)
       apply(drule rel_setD1[OF CD], erule bexE)
       apply(simp add: *)
      apply(drule rel_setD2[OF CD], erule bexE)
      apply(drule rel_setD2[OF AC], erule bexE)
      apply(drule rel_setD1[OF AB], erule bexE)
      apply(simp add: *)
      done
  }
qed

(* The domain predicate of equality is the predicate that is constantly True. *)
lemma Domainp_eq: "Domainp (=) = (λ_. True)"
by(simp add: Domainp.simps fun_eq_iff)

(* Relates `eq_onp (pred_fun P Q)` on a pair of functions to rel_fun between
   `eq_onp P` and `eq_onp Q`; proved by unfolding eq_onp_def and rel_fun_def. *)
lemma rel_fun_eq_onpI: "eq_onp (pred_fun P Q) f g  rel_fun (eq_onp P) (eq_onp Q) f g"
by(auto simp add: eq_onp_def rel_fun_def)

(* Every relation of the form `eq_onp P` is bi-unique. *)
lemma bi_unique_eq_onp: "bi_unique (eq_onp P)"
by(simp add: bi_unique_def eq_onp_def)

(* Converse commutes with the function relator when the result relation is (=). *)
lemma rel_fun_eq_conversep: includes lifting_syntax shows "(A¯¯ ===> (=)) = (A ===> (=))¯¯"
by(auto simp add: fun_eq_iff rel_fun_def)

(* Two rewrite equations for rel_fun applied to a combination of two functions
   (on the left resp. right side): the outer function is moved into the result
   relation, leaving only the inner function as the related function. *)
lemma rel_fun_comp:
  "f g h. rel_fun A B (f  g) h = rel_fun A (λx. B (f x)) g h"
  "f g h. rel_fun A B f (g  h) = rel_fun A (λx y. B x (g y)) f h"
  by(auto simp add: rel_fun_def)

(* Relates rel_fun over the converse of `Grp UNIV h` to rel_fun over equality
   after mapping the left function with `map_fun h id`. *)
lemma rel_fun_map_fun1: "rel_fun (BNF_Def.Grp UNIV h)¯¯ A f g  rel_fun (=) A (map_fun h id f) g"
  by(auto simp add: rel_fun_def Grp_def)

(* Expresses `map_fun f g x` through g and `map_fun f id x`, i.e. the
   post-processing function g is split off. *)
lemma map_fun2_id: "map_fun f g x = g  map_fun f id x"
  by(simp add: map_fun_def o_assoc)

(* Expresses `map_fun g h f` as `map_fun g id` applied to a combination of h and f,
   i.e. the post-processing function h is pushed into the mapped function. *)
lemma map_fun_id2_in: "map_fun g h f = map_fun g id (h  f)"
  by(simp add: map_fun_def)

(* Compares the domain predicate of `rel_fun A B` with
   `pred_fun (Domainp A) (Domainp B)`; proved using rel_funD. *)
lemma Domainp_rel_fun_le: "Domainp (rel_fun A B)  pred_fun (Domainp A) (Domainp B)"
  by(auto dest: rel_funD)

(* Witness function for a pair of functions (f, g) along relations A and A':
   at a point b it returns the pair of f applied to THE a with `A a b` and
   g applied to THE c with `A' b c`. *)
definition rel_witness_fun :: "('a  'b  bool)  ('b  'c  bool)  ('a  'd) × ('c  'e)  ('b  'd × 'e)" where
  "rel_witness_fun A A' = (λ(f, g) b. (f (THE a. A a b), g (THE c. A' b c)))"

(* Projection properties of rel_witness_fun.  Assuming f and g are related by
   rel_fun over the composition `A OO A'`, that A is left-unique and right-total,
   and that A' is right-unique and left-total, f is related to the witness along A
   (rel_witness_fun1) and the witness is related to g along A' (rel_witness_fun2). *)
lemma
  assumes fg: "rel_fun (A OO A') B f g"
    and A: "left_unique A" "right_total A"
    and A': "right_unique A'" "left_total A'"
  shows rel_witness_fun1: "rel_fun A (λx (x', y). x = x'  B x' y) f (rel_witness_fun A A' (f, g))"
    and rel_witness_fun2: "rel_fun A' (λ(x, y') y. y = y'  B x y') (rel_witness_fun A A' (f, g)) g"
proof (goal_cases)
  case 1
  (* Pointwise auxiliary fact: left-totality of A' supplies a partner of y, so
     that fg can be applied through relcomppI; the THE-terms are then resolved
     by uniqueness. *)
  have "A x y  f x = f (THE a. A a y)  B (f (THE a. A a y)) (g (The (A' y)))" for x y 
    by(rule left_totalE[OF A'(2)]; erule meta_allE[of _ y]; erule exE; frule (1) fg[THEN rel_funD, OF relcomppI])
      (auto intro!: arg_cong[where f=f] arg_cong[where f=g] rel_funI the_equality the_equality[symmetric] dest: left_uniqueD[OF A(1)] right_uniqueD[OF A'(1)] elim!: arg_cong2[where f=B, THEN iffD2, rotated -1])

  with 1 show ?case by(clarsimp simp add: rel_fun_def rel_witness_fun_def)
next
  case 2
  (* Symmetric auxiliary fact: right-totality of A supplies a partner of x. *)
  have "A' x y  g y = g (The (A' x))  B (f (THE a. A a x)) (g (The (A' x)))" for x y
    by(rule right_totalE[OF A(2), of x]; frule (1) fg[THEN rel_funD, OF relcomppI])
      (auto intro!: arg_cong[where f=f] arg_cong[where f=g] rel_funI the_equality the_equality[symmetric] dest: left_uniqueD[OF A(1)] right_uniqueD[OF A'(1)] elim!: arg_cong2[where f=B, THEN iffD2, rotated -1])

  with 2 show ?case by(clarsimp simp add: rel_fun_def rel_witness_fun_def)    
qed

(* With both relations instantiated to (=), the witness simply pairs the two
   function results pointwise. *)
lemma rel_witness_fun_eq [simp]: "rel_witness_fun (=) (=) (f, g) = (λx. (f x, g x))"
  by(simp add: rel_witness_fun_def)

subsection ‹Arithmetic›

(* Triangle-style statement comparing the absolute value of `a - b` with the sum
   of the absolute values of `a - c` and `c - b`; derived from
   abs_diff_triangle_ineq through order_trans. *)
lemma abs_diff_triangle_ineq2: "¦a - b :: _ :: ordered_ab_group_add_abs¦  ¦a - c¦ + ¦c - b¦"
by(rule order_trans[OF _ abs_diff_triangle_ineq]) simp

(* Chains a bound of x by `a + b` with a bound between b and c to obtain a bound
   of x by `a + c`; a combination of order_trans and add_left_mono. *)
lemma (in ordered_ab_semigroup_add) add_left_mono_trans:
  " x  a + b; b  c   x  a + c"
by(erule order_trans)(rule add_left_mono)

(* Simp rule relating a comparison of `real n` with 1 to the corresponding
   comparison of the natural number n with 1; discharged by linarith. *)
lemma of_nat_le_one_cancel_iff [simp]:
  fixes n :: nat shows "real n  1  n  1"
by linarith

(* Variant of mult_left_le with the factors of the product swapped
   (`c * a` instead of `a * c`); obtained by commuting the multiplication. *)
lemma (in linordered_semidom) mult_right_le: "c  1  0  a  c * a  a"
by(subst mult.commute)(rule mult_left_le)

subsection ‹Chain-complete partial orders and partial_function›

(* Destruction rule: a fun_ord fact between f and g yields the pointwise fact
   `ord (f x) (g x)`. *)
lemma fun_ordD: "fun_ord ord f g  ord (f x) (g x)"
by(simp add: fun_ord_def)

(* Strengthened parallel fixpoint induction for two ccpos: in the step case the
   induction hypothesis additionally provides that x and y are related (by orda
   resp. ordb) to the respective fixpoints of f and g. *)
lemma parallel_fixp_induct_strong:
  assumes ccpo1: "class.ccpo luba orda (mk_less orda)"
  and ccpo2: "class.ccpo lubb ordb (mk_less ordb)"
  and adm: "ccpo.admissible (prod_lub luba lubb) (rel_prod orda ordb) (λx. P (fst x) (snd x))"
  and f: "monotone orda orda f"
  and g: "monotone ordb ordb g"
  and bot: "P (luba {}) (lubb {})"
  and step: "x y.  orda x (ccpo.fixp luba orda f); ordb y (ccpo.fixp lubb ordb g); P x y   P (f x) (g y)"
  shows "P (ccpo.fixp luba orda f) (ccpo.fixp lubb ordb g)"
proof -
  (* Strengthened invariant: the two fixpoint bounds together with P. *)
  let ?P="λx y. orda x (ccpo.fixp luba orda f)  ordb y (ccpo.fixp lubb ordb g)  P x y"
  show ?thesis using ccpo1 ccpo2 _ f g
  (* Apply the plain rule to ?P and keep only the last component (P itself). *)
  proof(rule parallel_fixp_induct[where P="?P", THEN conjunct2, THEN conjunct2])
    note [cont_intro] = 
      admissible_leI[OF ccpo1] ccpo.mcont_const[OF ccpo1]
      admissible_leI[OF ccpo2] ccpo.mcont_const[OF ccpo2]
    show "ccpo.admissible (prod_lub luba lubb) (rel_prod orda ordb) (λxy. ?P (fst xy) (snd xy))"
      using adm by simp
    show "?P (luba {}) (lubb {})" using bot by(auto intro: ccpo.ccpo_Sup_least ccpo1 ccpo2 chain_empty)
    (* Step case: unfold both fixpoints once so that monotonicity of f and g
       re-establishes the bounds; P itself comes from the step assumption. *)
    show "?P (f x) (g y)" if "?P x y" for x y using that
      apply(subst ccpo.fixp_unfold[OF ccpo1 f])
      apply(subst ccpo.fixp_unfold[OF ccpo2 g])
      apply(auto intro: step monotoneD[OF f] monotoneD[OF g])
      done
  qed
qed

(* Version of parallel_fixp_induct_strong for functions defined as fixpoints
   through conversion pairs U1/C1 and U2/C2 (with `U1 (C1 f) = f` and
   `U2 (C2 g) = g`), over two partial_function_definitions instances. *)
lemma parallel_fixp_induct_strong_uc:
  assumes a: "partial_function_definitions orda luba"
  and b: "partial_function_definitions ordb lubb"
  and F: "x. monotone (fun_ord orda) orda (λf. U1 (F (C1 f)) x)"
  and G: "y. monotone (fun_ord ordb) ordb (λg. U2 (G (C2 g)) y)"
  and eq1: "f  C1 (ccpo.fixp (fun_lub luba) (fun_ord orda) (λf. U1 (F (C1 f))))"
  and eq2: "g  C2 (ccpo.fixp (fun_lub lubb) (fun_ord ordb) (λg. U2 (G (C2 g))))"
  and inverse: "f. U1 (C1 f) = f"
  and inverse2: "g. U2 (C2 g) = g"
  and adm: "ccpo.admissible (prod_lub (fun_lub luba) (fun_lub lubb)) (rel_prod (fun_ord orda) (fun_ord ordb)) (λx. P (fst x) (snd x))"
  and bot: "P (λ_. luba {}) (λ_. lubb {})"
  and step: "f' g'.  x. orda (U1 f' x) (U1 f x); y. ordb (U2 g' y) (U2 g y); P (U1 f') (U2 g')   P (U1 (F f')) (U2 (G g'))"
  shows "P (U1 f) (U2 g)"
(* Unfold the defining equations, then discharge the premises of
   parallel_fixp_induct_strong in order: monotonicity of F and of G, the bottom
   case, and finally the step case. *)
apply(unfold eq1 eq2 inverse inverse2)
apply(rule parallel_fixp_induct_strong[OF partial_function_definitions.ccpo[OF a] partial_function_definitions.ccpo[OF b] adm])
using F apply(simp add: monotone_def fun_ord_def)
using G apply(simp add: monotone_def fun_ord_def)
apply(simp add: fun_lub_def bot)
apply(rule step; simp add: inverse inverse2 eq1 eq2 fun_ordD)
done

(* Instance of parallel_fixp_induct_strong_uc in which all four conversion
   functions are the identity; the two inverse premises are closed by refl. *)
lemmas parallel_fixp_induct_strong_1_1 = parallel_fixp_induct_strong_uc[
  of _ _ _ _ "λx. x" _ "λx. x" "λx. x" _ "λx. x",
  OF _ _ _ _ _ _ refl refl]

(* Instance of parallel_fixp_induct_strong_uc with case_prod/curry as the
   conversion functions on both sides, post-processed with the case_prod/curry
   simplification rules and split_format. *)
lemmas parallel_fixp_induct_strong_2_2 = parallel_fixp_induct_strong_uc[
  of _ _ _ _ "case_prod" _ "curry" "case_prod" _ "curry",
  where P="λf g. P (curry f) (curry g)",
  unfolded case_prod_curry curry_case_prod curry_K,
  OF _ _ _ _ _ _ refl refl,
  split_format (complete), unfolded prod.case]
  for P

(* Induction rule for option-valued fixpoint definitions.  Compared with a plain
   rule, the step hypothesis also supplies an option_ord fact between `U g x` and
   `U f x` for all x.  Derived from option.fixp_strong_induct_uc. *)
lemma fixp_induct_option': ― ‹Stronger induction rule›
  fixes F :: "'c  'c" and
    U :: "'c  'b  'a option" and
    C :: "('b  'a option)  'c" and
    P :: "'b  'a  bool"
  assumes mono: "x. mono_option (λf. U (F (C f)) x)"
  assumes eq: "f  C (ccpo.fixp (fun_lub (flat_lub None)) (fun_ord option_ord) (λf. U (F (C f))))"
  assumes inverse2: "f. U (C f) = f"
  assumes step: "g x y.  x y. U g x = Some y  P x y; U (F g) x = Some y; x. option_ord (U g x) (U f x)   P x y"
  assumes defined: "U f x = Some y"
  shows "P x y"
using step defined option.fixp_strong_induct_uc[of U F C, OF mono eq inverse2 option_admissible, of P]
unfolding fun_lub_def flat_lub_def fun_ord_def
by(simp (no_asm_use)) blast

(* Registers a partial_function mode named "option'" via Partial_Function.init.
   It reuses the option fixpoint constant, monotonicity predicate and
   fixp_rule_uc / fixp_induct_uc theorems, and passes fixp_induct_option' as the
   optional final argument. *)
declaration Partial_Function.init "option'" @{term option.fixp_fun}
  @{term option.mono_body} @{thm option.fixp_rule_uc} @{thm option.fixp_induct_uc}
  (SOME @{thm fixp_induct_option'})

(* Simp rule comparing the constant-bot function with an arbitrary x; proved by
   folding the constant function into bot (bot_fun_def) and simplifying. *)
lemma bot_fun_least [simp]: "(λ_. bot :: 'a :: order_bot)  x"
by(fold bot_fun_def) simp

(* fun_ord coincides with the function relator whose argument relation is (=). *)
lemma fun_ord_conv_rel_fun: "fun_ord = rel_fun (=)"
by(simp add: fun_ord_def fun_eq_iff rel_fun_def)

(* Predicate on an order `ord`: its single introduction rule concludes
   `finite_chains ord` from a premise about chains Y of ord and `finite Y`. *)
inductive finite_chains :: "('a  'a  bool)  bool"
  for ord
where finite_chainsI: "(Y. Complete_Partial_Order.chain ord Y  finite Y)  finite_chains ord"

(* Destruction rule for finite_chains: a chain Y of such an order is finite.
   Obtained by case analysis on the inductive predicate. *)
lemma finite_chainsD: " finite_chains ord; Complete_Partial_Order.chain ord Y   finite Y"
by(rule finite_chains.cases)

(* Flat orders have finite chains.  The proof bounds any chain Y by a finite set:
   by `{x, y}` when the case-split condition on an element y of Y holds, and by
   `{x}` otherwise. *)
lemma finite_chains_flat_ord [simp, intro!]: "finite_chains (flat_ord x)"
proof
  fix Y
  assume chain: "Complete_Partial_Order.chain (flat_ord x) Y"
  show "finite Y"
  proof(cases "y  Y. y  x")
    case True
    then obtain y where y: "y  Y" and yx: "y  x" by blast
    hence "Y  {x, y}" by(auto dest: chainD[OF chain] simp add: flat_ord_def)
    thus ?thesis by(rule finite_subset) simp
  next
    case False
    hence "Y  {x}" by auto
    thus ?thesis by(rule finite_subset) simp
  qed
qed    

(* A monotone function between two ccpos is mcont when the source order has
   finite chains. *)
lemma mcont_finite_chains:
  assumes finite: "finite_chains ord"
  and mono: "monotone ord ord' f"
  and ccpo: "class.ccpo lub ord (mk_less ord)"
  and ccpo': "class.ccpo lub' ord' (mk_less ord')"
  shows "mcont lub ord lub' ord' f"
proof(intro mcontI contI)
  fix Y
  assume chain: "Complete_Partial_Order.chain ord Y" and Y: "Y  {}"
  from finite chain have fin: "finite Y" by(rule finite_chainsD)
  (* Key fact, from ccpo.in_chain_finite applied to the finite chain Y. *)
  from ccpo chain fin Y have lub: "lub Y  Y" by(rule ccpo.in_chain_finite)

  interpret ccpo': ccpo lub' ord' "mk_less ord'" by(rule ccpo')

  have chain': "Complete_Partial_Order.chain ord' (f ` Y)" using chain
    by(rule chain_imageI)(rule monotoneD[OF mono])

  (* The continuity equation follows from two opposite ord' facts by antisymmetry. *)
  have "ord' (f (lub Y)) (lub' (f ` Y))" using chain'
    by(rule ccpo'.ccpo_Sup_upper)(simp add: lub)
  moreover
  have "ord' (lub' (f ` Y)) (f (lub Y))" using chain'
    by(rule ccpo'.ccpo_Sup_least)(blast intro: monotoneD[OF mono] ccpo.ccpo_Sup_upper[OF ccpo chain])
  ultimately show "f (lub Y) = lub' (f ` Y)" by(rule ccpo'.order.antisym)
qed(fact mono)  

(* Relates the curried function relation `A ===> B ===> C` on f and g to the
   relation `rel_prod A B ===> C` on their case_prod (uncurried) forms. *)
lemma rel_fun_curry: includes lifting_syntax shows
  "(A ===> B ===> C) f g  (rel_prod A B ===> C) (case_prod f) (case_prod g)"
by(auto simp add: rel_fun_def)

(* Compares `Sup (f ` A)` with `f (luba A)` for a function f that is monotone from
   another ccpo (luba, orda, lessa) into this one, where A is a chain in orda.
   Proved with ccpo_Sup_least: each element `f y` of the image is handled via
   ccpo_Sup_upper in the source ccpo and monotonicity of f. *)
lemma (in ccpo) Sup_image_mono:
  assumes ccpo: "class.ccpo luba orda lessa"
  and mono: "monotone orda (≤) f"
  and chain: "Complete_Partial_Order.chain orda A"
  and "A  {}"
  shows "Sup (f ` A)  (f (luba A))"
proof(rule ccpo_Sup_least)
  from chain show "Complete_Partial_Order.chain (≤) (f ` A)"
    by(rule chain_imageI)(rule monotoneD[OF mono])
  fix x
  assume "x  f ` A"
  then obtain y where "x = f y" "y  A" by blast
  from y  A have "orda y (luba A)" by(rule ccpo.ccpo_Sup_upper[OF ccpo chain])
  hence "f y  f (luba A)" by(rule monotoneD[OF mono])
  thus "x  f (luba A)" using x = f y by simp
qed

(* Admissibility of the predicate comparing x with `f x`, for monotone f.
   The proof is a two-step calculation that goes through the image `f ` Y` and
   then applies Sup_image_mono. *)
lemma (in ccpo) admissible_le_mono:
  assumes "monotone (≤) (≤) f"
  shows "ccpo.admissible Sup (≤) (λx. x  f x)"
proof(rule ccpo.admissibleI)
  fix Y
  assume chain: "Complete_Partial_Order.chain (≤) Y"
    and Y: "Y  {}"
    and le [rule_format]: "xY. x  f x"
  have "Y  (f ` Y)" using chain
    by(rule ccpo_Sup_least)(rule order_trans[OF le]; blast intro!: ccpo_Sup_upper chain_imageI[OF chain] intro: monotoneD[OF assms])
  also have "  f (Y)"
    by(rule Sup_image_mono[OF _ assms chain Y, where lessa="(<)"]) unfold_locales
  finally show "Y  " .
qed

(* Fixpoint induction whose step hypothesis provides, besides `P x`, a fact
   relating x to the fixpoint and a fact relating x to `f x`.  Obtained from
   fixp_strong_induct applied to the conjunction of the `x`/`f x` fact and P. *)
lemma (in ccpo) fixp_induct_strong2:
  assumes adm: "ccpo.admissible Sup (≤) P"
  and mono: "monotone (≤) (≤) f"
  and bot: "P ({})"
  and step: "x.  x  ccpo_class.fixp f; x  f x; P x   P (f x)"
  shows "P (ccpo_class.fixp f)"
proof(rule fixp_strong_induct[where P="λx. x  f x  P x", THEN conjunct2])
  (* Admissibility of the strengthened predicate: admissible_le_mono combined
     with adm through admissible_conj. *)
  show "ccpo.admissible Sup (≤) (λx. x  f x  P x)"
    using admissible_le_mono adm by(rule admissible_conj)(rule mono)
next
  show "{}  f ({})  P ({})"
    by(auto simp add: bot chain_empty intro: ccpo_Sup_least)
next
  fix x
  assume "x  ccpo_class.fixp f" "x  f x  P x"
  thus "f x  f (f x)  P (f x)"
    by(auto dest: monotoneD[OF mono] intro: step)
qed(rule mono)

context partial_function_definitions begin

(* Version of ccpo.fixp_induct_strong2 for functions defined as fixp_fun through
   a conversion pair U/C with `U (C f) = f`.  The step hypothesis provides two
   le_fun facts (against `U f` and against `U (F f')`) in addition to `P (U f')`. *)
lemma fixp_induct_strong2_uc:
  fixes F :: "'c  'c"
    and U :: "'c  'b  'a"
    and C :: "('b  'a)  'c"
    and P :: "('b  'a)  bool"
  assumes mono: "x. mono_body (λf. U (F (C f)) x)"
    and eq: "f  C (fixp_fun (λf. U (F (C f))))"
    and inverse: "f. U (C f) = f"
    and adm: "ccpo.admissible lub_fun le_fun P"
    and bot: "P (λ_. lub {})"
    and step: "f'.  le_fun (U f') (U f); le_fun (U f') (U (F f')); P (U f')   P (U (F f'))"
  shows "P (U f)"
unfolding eq inverse
apply (rule ccpo.fixp_induct_strong2[OF ccpo adm])
(* The next two goals (monotonicity and the bottom case) are closed by auto. *)
apply (insert mono, auto simp: monotone_def fun_ord_def bot fun_lub_def)[2]
(* Step case: instantiate the step assumption's function variable with `C x`. *)
apply (rule_tac f'5="C x" in step)
apply (simp_all add: inverse eq)
done

end

(* Instance of parallel_fixp_induct_uc for a 2-argument function on the first side
   (case_prod / curry) and a 4-argument function on the second side
   (three nested case_prod / curry). *)
lemmas parallel_fixp_induct_2_4 = parallel_fixp_induct_uc[
  of _ _ _ _ "case_prod" _ "curry" "λf. case_prod (case_prod (case_prod f))" _ "λf. curry (curry (curry f))",
  where P="λf g. P (curry f) (curry (curry (curry g)))",
  unfolded case_prod_curry curry_case_prod curry_K,
  OF _ _ _ _ _ _ refl refl]
  for P
  
(* Relates x to the fixpoint of a monotone f, given the hypothesis `ge` about all
   y with the stated `f y` / `y` property.  The proof instantiates `ge` with the
   fixpoint itself and closes the side condition using fixp_unfold. *)
lemma (in ccpo) fixp_greatest:
  assumes f: "monotone (≤) (≤) f"
    and ge: "y. f y  y  x  y"
  shows "x  ccpo.fixp Sup (≤) f"
  by(rule ge)(simp add: fixp_unfold[OF f, symmetric])

(* Rolling rule for fixpoints over two ccpos: the fixpoint of `λx. g (f x)` equals
   g applied to the fixpoint of `λx. f (g x)`, for monotone f and g. *)
lemma fixp_rolling:
  assumes "class.ccpo lub1 leq1 (mk_less leq1)"
    and "class.ccpo lub2 leq2 (mk_less leq2)"
    and f: "monotone leq1 leq2 f"
    and g: "monotone leq2 leq1 g"
  shows "ccpo.fixp lub1 leq1 (λx. g (f x)) = g (ccpo.fixp lub2 leq2 (λx. f (g x)))"
proof -
  interpret c1: ccpo lub1 leq1 "mk_less leq1" by fact
  interpret c2: ccpo lub2 leq2 "mk_less leq2" by fact
  show ?thesis
  (* Antisymmetry in the first ccpo: prove both leq1 directions. *)
  proof(rule c1.order.antisym)
    (* Both compositions are monotone. *)
    have fg: "monotone leq2 leq2 (λx. f (g x))" using f g by(rule monotone2monotone) simp_all
    have gf: "monotone leq1 leq1 (λx. g (f x))" using g f by(rule monotone2monotone) simp_all
    (* First direction: fixp_lowerbound after unfolding the inner fixpoint once. *)
    show "leq1 (c1.fixp (λx. g (f x))) (g (c2.fixp (λx. f (g x))))" using gf
      by(rule c1.fixp_lowerbound)(subst (2) c2.fixp_unfold[OF fg], simp)
    (* Second direction: via c1.fixp_greatest, for an arbitrary u with the
       assumed leq1 fact between `g (f u)` and u. *)
    show "leq1 (g (c2.fixp (λx. f (g x)))) (c1.fixp (λx. g (f x)))" using gf
    proof(rule c1.fixp_greatest)
      fix u
      assume u: "leq1 (g (f u)) u"
      have "leq1 (g (c2.fixp (λx. f (g x)))) (g (f u))"
        by(intro monotoneD[OF g] c2.fixp_lowerbound[OF fg] monotoneD[OF f u])
      then show "leq1 (g (c2.fixp (λx. f (g x)))) u" using u by(rule c1.order_trans)
    qed
  qed
qed

(* Parametricity of lfp.fixp_fun for results related by equality: if the
   functionals F and G are related by `(A ===> (=)) ===> A ===> (=)` and both
   satisfy lfp.mono_body pointwise, their fixpoints are related by `A ===> (=)`.
   Proved by parallel fixpoint induction (parallel_fixp_induct_1_1). *)
lemma fixp_lfp_parametric_eq:
  includes lifting_syntax
  assumes f: "x. lfp.mono_body (λf. F f x)"
  and g: "x. lfp.mono_body (λf. G f x)"
  and param: "((A ===> (=)) ===> A ===> (=)) F G"
  shows "(A ===> (=)) (lfp.fixp_fun F) (lfp.fixp_fun G)"
using f g
proof(rule parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions _ _ reflexive reflexive, where P="(A ===> (=))"])
  show "ccpo.admissible (prod_lub lfp.lub_fun lfp.lub_fun) (rel_prod lfp.le_fun lfp.le_fun) (λx. (A ===> (=)) (fst x) (snd x))"
    unfolding rel_fun_def by simp
  show "(A ===> (=)) (λ_. {}) (λ_. {})" by auto
  show "(A ===> (=)) (F f) (G g)" if "(A ===> (=)) f g" for f g
    using that by(rule rel_funD[OF param])
qed

(* map_option f is monotone w.r.t. option_ord (stored as monotone_map_option);
   the outer name stores the option.mono2mono-derived form as a simp and
   cont_intro rule. *)
lemma mono2mono_map_option[THEN option.mono2mono, simp, cont_intro]:
  shows monotone_map_option: "monotone option_ord option_ord (map_option f)"
by(rule monotoneI)(auto simp add: flat_ord_def)

(* map_option f is mcont on the flat option ccpo (stored as mcont_map_option);
   follows from mcont_finite_chains with flat_interpretation on both sides. *)
lemma mcont2mcont_map_option[THEN option.mcont2mcont, simp, cont_intro]:
  shows mcont_map_option: "mcont (flat_lub None) option_ord (flat_lub None) option_ord (map_option f)"
by(rule mcont_finite_chains[OF _ _ flat_interpretation[THEN ccpo] flat_interpretation[THEN ccpo]]) simp_all

(* set_option is monotone from option_ord into subset inclusion
   (stored as monotone_set_option). *)
lemma mono2mono_set_option [THEN lfp.mono2mono]:
  shows monotone_set_option: "monotone option_ord (⊆) set_option"
by(auto intro!: monotoneI simp add: option_ord_Some1_iff)

(* set_option is mcont from the flat option ccpo into sets with Union and subset
   (stored as mcont_set_option); via mcont_finite_chains. *)
lemma mcont2mcont_set_option [THEN lfp.mcont2mcont, cont_intro, simp]:
  shows mcont_set_option: "mcont (flat_lub None) option_ord Union (⊆) set_option"
by(rule mcont_finite_chains)(simp_all add: monotone_set_option ccpo option.partial_function_definitions_axioms)

(* partial_function monotonicity rule for addition on enat with respect to the
   (≥) ordering: pointwise sum of two monotone functionals is monotone.
   Restates mono2mono_gfp_eadd. *)
lemma eadd_gfp_partial_function_mono [partial_function_mono]:
  " monotone (fun_ord (≥)) (≥) f; monotone (fun_ord (≥)) (≥) g 
   monotone (fun_ord (≥)) (≥) (λx. f x + g x :: enat)"
by(rule mono2mono_gfp_eadd)

(* partial_function monotonicity rule for map_option: proved by rewriting
   map_option into bind (map_conv_bind_option) and applying bind_mono. *)
lemma map_option_mono [partial_function_mono]:
  "mono_option B  mono_option (λf. map_option g (B f))"
unfolding map_conv_bind_option by(rule bind_mono) simp_all


subsection ‹Folding over finite sets›

(* Invariant rule for Finite_Set.fold where the invariant I tracks the set of
   elements still to be processed: it holds initially for (A, s), each step
   moves from A' to `A' - {x}` while applying `f x`, and the conclusion is
   I for the empty set and the fold result. *)
lemma (in comp_fun_commute) fold_invariant_remove [consumes 1, case_names start step]:
  assumes fin: "finite A"
  and start: "I A s"
  and step: "x s A'.  x  A'; I A' s; A'  A   I (A' - {x}) (f x s)"
  shows "I {} (Finite_Set.fold f s A)"
proof -
  (* Generalise A to a fresh A' (related to A) so that the induction can vary it. *)
  define A' where "A' == A"
  with fin start have "finite A'" "A'  A" "I A' s" by simp_all
  thus "I {} (Finite_Set.fold f s A')"
  proof(induction arbitrary: s)
    case empty thus ?case by simp
  next
    case (insert x A')
    let ?A' = "insert x A'"
    have "x  ?A'" "I ?A' s" "?A'  A" using insert by auto
    hence "I (?A' - {x}) (f x s)" by(rule step)
    with insert have "A'  A" "I A' (f x s)" by auto
    hence "I {} (Finite_Set.fold f (f x s) A')" by(rule insert.IH)
    (* fold_insert2 is used in place of the default fold_insert rule. *)
    thus ?case using insert by(simp add: fold_insert2 del: fold_insert)
  qed
qed

(* Invariant rule for Finite_Set.fold where the invariant I tracks the set of
   elements already processed: it starts at the empty set, and each step inserts
   x into A'.  Derived from fold_invariant_remove with the invariant
   `λA'. I (A - A')`. *)
lemma (in comp_fun_commute) fold_invariant_insert [consumes 1, case_names start step]:
  assumes fin: "finite A"
  and start: "I {} s"
  and step: "x s A'.  I A' s; x  A'; x  A; A'  A   I (insert x A') (f x s)"
  shows "I A (Finite_Set.fold f s A)"
using fin start
proof(rule fold_invariant_remove[where I="λA'. I (A - A')" and A=A and s=s, simplified])
  fix x s A'
  assume *: "x  A'" "I (A - A') s" "A'  A"
  hence "x  A - A'" "x  A" "A - A'  A" by auto
  with I (A - A') s have "I (insert x (A - A')) (f x s)" by(rule step)
  (* Rewrite the set into the shape expected by fold_invariant_remove. *)
  also have "insert x (A - A') = A - (A' - {x})" using * by auto
  finally show "I  (f x s)" .
qed

(* Folding over a combination of two finite sets A and B equals folding over B
   starting from the result of folding over A.  Proved by induction on the
   finiteness of B (the assumptions are passed in the order 2, 1). *)
lemma (in comp_fun_idem) fold_set_union:
  assumes "finite A" "finite B"
  shows "Finite_Set.fold f z (A  B) = Finite_Set.fold f (Finite_Set.fold f z A) B"
using assms(2,1) by induction simp_all


subsection ‹Parametrisation of transfer rules›

(* Attribute `transfer_parametric`: takes a parametricity theorem as argument and,
   applied to a transfer rule, first runs Lifting_Term.parametrize_transfer_rule
   on that rule and then Lifting_Def.generate_parametric_transfer_rule with the
   given parametricity theorem.  A Lifting_Term.MERGE_TRANSFER_REL exception is
   turned into a regular error carrying the pretty-printed message. *)
attribute_setup transfer_parametric = Attrib.thm >> (fn parametricity =>
    Thm.rule_attribute [] (fn context => fn transfer_rule =>
      let
        val ctxt = Context.proof_of context;
        val thm' = Lifting_Term.parametrize_transfer_rule ctxt transfer_rule
      in Lifting_Def.generate_parametric_transfer_rule ctxt thm' parametricity
      end
      handle Lifting_Term.MERGE_TRANSFER_REL msg => error (Pretty.string_of msg)
      )) "combine transfer rule with parametricity theorem"

subsection ‹Lists›

(* Lifts an nth fact on xs at index n to the list `x # xs` at index `Suc n`. *)
lemma nth_eq_tlI: "xs ! n = z  (x # xs) ! Suc n = z"
by simp

(* Splits list_all2 over two appended lists into list_all2 on the front parts and
   on the back parts, under the assumption that the back parts us and vs have
   equal length. *)
lemma list_all2_append':
  "length us = length vs  list_all2 P (xs @ us) (ys @ vs)  list_all2 P xs ys  list_all2 P us vs"
by(auto simp add: list_all2_append1 list_all2_append2 dest: list_all2_lengthD)

(* A list of predicates is "disjoint" when the family of sets `{x. (xs ! n) x}`,
   indexed by the positions `{0..<length xs}`, is a disjoint family. *)
definition disjointp :: "('a  bool) list  bool"
where "disjointp xs = disjoint_family_on (λn. {x. (xs ! n) x}) {0..<length xs}"

(* Destruction rule for disjointp: if the predicates at valid positions n and m
   both hold for the same x, then n = m. *)
lemma disjointpD:
  " disjointp xs; (xs ! n) x; (xs ! m) x; n < length xs; m < length xs   n = m"
by(auto 4 3 simp add: disjointp_def disjoint_family_on_def)

(* Variant of disjointpD in which the two predicates are named P and Q and tied
   to the list positions by equations `xs ! n = P` and `xs ! m = Q`. *)
lemma disjointpD':
  " disjointp xs; P x; Q x; xs ! n = P; xs ! m = Q; n < length xs; m < length xs   n = m"
by(auto 4 3 simp add: disjointp_def disjoint_family_on_def)

(* strict_prefix is well-founded.  The relation is embedded into the inverse
   image of `<` under length (using prefix_length_less), which is well-founded
   by wf_inv_image; wf_subset then gives the result. *)
lemma wf_strict_prefix: "wfP strict_prefix"
proof -
  from wf have "wf (inv_image {(x, y). x < y} length)" by(rule wf_inv_image)
  moreover have "{(x, y). strict_prefix x y}  inv_image {(x, y). x < y} length" by(auto intro: prefix_length_less)
  ultimately show ?thesis unfolding wfP_def by(rule wf_subset)
qed

(* From `strict_prefix xs ys`, relates the element sets `set xs` and `set ys`;
   proved by unfolding strict_prefix_def and prefix_def. *)
lemma strict_prefix_setD:
  "strict_prefix xs ys  set xs  set ys"
  by(auto simp add: strict_prefix_def prefix_def)

subsubsection ‹Lists of a given length›

(* `nlists A n`: set of lists xs characterised by a condition on `set xs` and A
   together with `length xs = n`.  The introduction rule's name is hidden
   afterwards (hide_fact), so the set is accessed through the derived lemmas. *)
inductive_set nlists :: "'a set  nat  'a list set" for A n
where nlists: " set xs  A; length xs = n   xs  nlists A n"
hide_fact (open) nlists

(* Set-comprehension form of nlists, obtained from nlists.simps. *)
lemma nlists_alt_def: "nlists A n = {xs. set xs  A  length xs = n}"
by(auto simp add: nlists_alt_def)

(* Over the empty carrier, nlists contains only [] for n = 0 and is empty otherwise. *)
lemma nlists_empty: "nlists {} n = (if n = 0 then {[]} else {})"
by(auto simp add: nlists_alt_def)

(* Simp rule: for n > 0 there are no lists of length n over the empty carrier. *)
lemma nlists_empty_gt0 [simp]: "n > 0  nlists {} n = {}"
by(simp add: nlists_empty)

(* Simp rule: the only list of length 0 is the empty list, for any carrier A. *)
lemma nlists_0 [simp]: "nlists A 0 = {[]}"
by(auto simp add: nlists_alt_def)

(* Simp rule reducing membership of `x # xs` in `nlists A (Suc n)` to conditions
   on x and A and on xs and `nlists A n`. *)
lemma Cons_in_nlists_Suc [simp]: "x # xs  nlists A (Suc n)  x  A  xs  nlists A n"
by(simp add: nlists_alt_def)

(* Simp rule relating membership of [] in `nlists A n` to `n = 0`. *)
lemma Nil_in_nlists [simp]: "[]  nlists A n  n = 0"
by(auto simp add: nlists_alt_def)

(* Characterises membership of `x # xs` in `nlists A n` for arbitrary n through a
   predecessor n' with `n = Suc n'`; by case distinction on n. *)
lemma Cons_in_nlists_iff: "x # xs  nlists A n  (n'. n = Suc n'  x  A  xs  nlists A n')"
by(cases n) simp_all

(* Characterises membership of xs in `nlists A (Suc n)` through a decomposition
   `xs = x # xs'`; by case distinction on xs. *)
lemma in_nlists_Suc_iff: "xs  nlists A (Suc n)  (x xs'. xs = x # xs'  x  A  xs'  nlists A n)"
by(cases xs) simp_all

(* Recursive equation: `nlists A (Suc n)` is built from the images of
   `nlists A n` under `(#) x` for the elements x of A. *)
lemma nlists_Suc: "nlists A (Suc n) = (xA. (#) x ` nlists A n)"
by(auto 4 3 simp add: in_nlists_Suc_iff intro: rev_image_eqI)

(* For an element x of A, `replicate n x` belongs to `nlists A n`
   (declared as simp and intro rule). *)
lemma replicate_in_nlists [simp, intro]: "x  A  replicate n x  nlists A n"
by(simp add: nlists_alt_def set_replicate_conv_if)

(* Simp rule relating emptiness of `nlists A n` to `n > 0` and `A = {}`;
   uses replicate_in_nlists as a source of elements. *)
lemma nlists_eq_empty_iff [simp]: "nlists A n = {}  n > 0  A = {}"
using replicate_in_nlists by(cases n)(auto)

(* Finiteness of A carries over to `nlists A n`; by induction on n with nlists_Suc. *)
lemma finite_nlists [simp]: "finite A  finite (nlists A n)"
by(induction n)(simp_all add: nlists_Suc)

(* From finiteness of `nlists A n`, concludes a statement about `finite A` and
   `n = 0`.  In the case where n is a successor, A is written as the image of
   `nlists A n` under hd, which is finite. *)
lemma finite_nlistsD: 
  assumes "finite (nlists A n)"
  shows "finite A  n = 0"
proof(rule disjCI)
  assume "n  0"
  then obtain n' where n: "n = Suc n'" by(cases n)auto
  then have "A = hd ` nlists A n" by(auto 4 4 simp add: nlists_Suc intro: rev_image_eqI rev_bexI)
  also have "finite " using assms ..
  finally show "finite A" .
qed

(* Combines finite_nlistsD with the default simp rules into a characterisation of
   finiteness of `nlists A n` through `finite A` and `n = 0`. *)
lemma finite_nlists_iff: "finite (nlists A n)  finite A  n = 0"
by(auto dest: finite_nlistsD)

(* The number of lists of length n over A is `card A ^ n`.  By induction on n; the
   successor case computes the cardinality of the union from nlists_Suc,
   distinguishing whether A is finite. *)
lemma card_nlists: "card (nlists A n) = card A ^ n"
proof(induction n)
  case (Suc n)
  have "card (xA. (#) x ` nlists A n) = card A * card (nlists A n)"
  proof(cases "finite A")
    case True
    then show ?thesis by(subst card_UN_disjoint)(auto simp add: card_image inj_on_def)
  next
    case False
    (* For infinite A the union is infinite as well, so both sides simplify. *)
    hence "¬ finite (xA. (#) x ` nlists A n)"
      unfolding nlists_Suc[symmetric] by(auto dest: finite_nlistsD)
    then show ?thesis using False by simp
  qed
  then show ?case using Suc.IH by(simp add: nlists_Suc)
qed simp

(* Over the carrier UNIV, membership in `nlists UNIV n` is determined by
   `length xs = n` alone. *)
lemma in_nlists_UNIV: "xs  nlists UNIV n  length xs = n"
by(simp add: nlists_alt_def)

subsubsection ‹ The type of lists of a given length ›

(* Type of lists over 'a whose length is LENGTH('b), carved out of
   `nlists UNIV (LENGTH('b))`.  Non-emptiness is witnessed by a replicate list of
   `undefined`.  setup_lifting registers the type with the Lifting package. *)
typedef (overloaded) ('a, 'b :: len0) nlist = "nlists (UNIV :: 'a set) (LENGTH('b))"
proof
  show "replicate LENGTH('b) undefined  ?nlist" by simp
qed

setup_lifting type_definition_nlist

subsection ‹Streams and infinite lists›

(* `sprefix xs ys`: relation between a finite list xs and a stream ys, defined by
   recursion on xs.  The empty list is always accepted; for `x # xs` the
   definition refers to `x = shd ys` and to the recursive call on `stl ys`. *)
primrec sprefix :: "'a list  'a stream  bool" where
  sprefix_Nil: "sprefix [] ys = True"
| sprefix_Cons: "sprefix (x # xs) ys  x = shd ys  sprefix xs (stl ys)"

(* Decomposes sprefix of an appended list `xs @ ys` into sprefix of xs and sprefix
   of ys on the stream after dropping `length xs` elements. *)
lemma sprefix_append: "sprefix (xs @ ys) zs  sprefix xs zs  sprefix ys (sdrop (length xs) zs)"
by(induct xs arbitrary: zs) simp_all

(* Taking the first n elements of a stream yields an sprefix of that stream. *)
lemma sprefix_stake_same [simp]: "sprefix (stake n xs) xs"
by(induct n arbitrary: xs) simp_all

(* Two sprefixes of the same stream that have equal length are equal.
   Proved by simultaneous induction on both lists (list_induct2), which is why
   the length assumption is passed first. *)
lemma sprefix_same_imp_eq:
  assumes "sprefix xs ys" "sprefix xs' ys"
  and "length xs = length xs'"
  shows "xs = xs'"
using assms(3,1,2) by(induct arbitrary: ys rule: list_induct2) auto

(* A list is an sprefix of any stream obtained by prepending it (`xs @- ys`). *)
lemma sprefix_shift_same [simp]:
  "sprefix xs (xs @- ys)"
by(induct xs) simp_all

(* Under a length condition on xs and ys, relates `sprefix xs (ys @- zs)` to the
   list-level `prefix xs ys`. *)
lemma sprefix_shift [simp]:
  "length xs  length ys  sprefix xs (ys @- zs)  prefix xs ys"
by(induct xs arbitrary: ys)(simp, case_tac ys, auto)

(* Relates list-level `prefix xs (stake n ys)` to a condition on `length xs` and n
   together with `sprefix xs ys`.  By induction on xs; the Cons case splits on
   the stream and on n simultaneously. *)
lemma prefixeq_stake2 [simp]: "prefix xs (stake n ys)  length xs  n  sprefix xs ys"
proof(induct xs arbitrary: n ys)
  case (Cons x xs)
  thus ?case by(cases ys n rule: stream.exhaust[case_product nat.exhaust]) auto
qed simp

(* Relates a statement about `tlength xs` to `¬ tfinite xs`; proved by transfer
   (tllist.lifting) from llength_eq_infty_conv_lfinite. *)
lemma tlength_eq_infinity_iff: "tlength xs =   ¬ tfinite xs"
including tllist.lifting by transfer(simp add: llength_eq_infty_conv_lfinite)

subsection ‹Monomorphic monads›

context includes lifting_syntax begin
(* Adjusts the background naming with the mandatory path "monad"; the constants
   and facts below are referred to as `monad.bind_option` etc. later in the file. *)
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "monad")

(* Bind for option values into an arbitrary carrier 'm with a failure element:
   a case distinction on x with `fail` in the None branch and `f x'` in the
   `Some x'` branch. *)
definition bind_option :: "'m fail  'a option  ('a  'm)  'm"
where "bind_option fail x f = (case x of None  fail | Some x'  f x')" for fail

(* Derives the per-constructor equations from the case-based definition. *)
simps_of_case bind_option_simps [simp]: bind_option_def

(* bind_option is parametric in the carrier relation M and the element relation B. *)
lemma bind_option_parametric [transfer_rule]:
  "(M ===> rel_option B ===> (B ===> M) ===> M) bind_option bind_option"
unfolding bind_option_def by transfer_prover

(* Binding into a constant continuation `λ_. m` yields m, under a side condition
   connecting `x = None` with `m = fail`; by cases on x. *)
lemma bind_option_K:
  "monad. (x = None  m = fail)  bind_option fail x (λ_. m) = m"
by(cases x) simp_all

end

(* With None as the failure element, monad.bind_option is Option.bind. *)
lemma bind_option_option [simp]: "monad.bind_option None = Option.bind"
by(simp add: monad.bind_option_def fun_eq_iff split: option.split)

context monad_fail_hom begin

(* The homomorphism h distributes over monad.bind_option: fail1 becomes fail2 and
   h is applied to the continuation f.  By cases on x. *)
lemma hom_bind_option: "h (monad.bind_option fail1 x f) = monad.bind_option fail2 x (h  f)"
by(cases x)(simp_all)

end

(* With fail_set as the failure element, monad.bind_option is expressed through
   the image of `set_option x` under f. *)
lemma bind_option_set [simp]: "monad.bind_option fail_set = (λx f.  (f ` set_option x))"
by(simp add: monad.bind_option_def fun_eq_iff split: option.split)

(* run_state commutes with monad.bind_option: running a bind over `fail_state fail`
   in state s equals binding over `fail` with each continuation run in s. *)
lemma run_bind_option_stateT [simp]:
  "more. run_state (monad.bind_option (fail_state fail) x f) s = 
  monad.bind_option fail x (λy. run_state (f y) s)"
by(cases x) simp_all

(* run_env commutes with monad.bind_option, analogously to the run_state case. *)
lemma run_bind_option_envT [simp]:
  "more. run_env (monad.bind_option (fail_env fail) x f) s = 
  monad.bind_option fail x (λy. run_env (f y) s)"
by(cases x) simp_all


subsection ‹Measures›

(* Make sets_restrict_space_count_space available as a measurable_cong rule. *)
declare sets_restrict_space_count_space [measurable_cong]

(* Measurability of the set where a unique index i (of countable type) satisfies
   `P i x`, given measurability of each `{x. P i x}` set.  Specialises
   sets_Collect_countable_Ex1' to the index set UNIV. *)
lemma (in sigma_algebra) sets_Collect_countable_Ex1:
  "(i :: 'i :: countable. {x  Ω. P i x}  M)  {x  Ω. ∃!i. P i x}  M"
using sets_Collect_countable_Ex1'[of "UNIV :: 'i set"] by simp

(* Measurable.pred version of sets.sets_Collect_countable_Ex1: unique existence
   over a countable index type preserves measurability of predicates. *)
lemma pred_countable_Ex1 [measurable]:
  "(i :: _ :: countable. Measurable.pred M (λx. P i x))
   Measurable.pred M (λx. ∃!i. P i x)"
unfolding pred_def by(rule sets.sets_Collect_countable_Ex1)

(* Measurability of snd from a pair measure with `count_space A` as second
   component into `count_space B`, under a condition relating A and B. *)
lemma measurable_snd_count_space [measurable]: 
  "A  B  snd  measurable (M1 M count_space A) (count_space B)"
by(auto simp add: measurable_def space_pair_measure snd_vimage_eq_Times Times_Int_Times)

(* Integrability of f w.r.t. M carries over to `scale_measure r M`, under a
   strict upper bound on the scaling factor r. *)
lemma integrable_scale_measure [simp]:
  " integrable M f; r <    integrable (scale_measure r M) f" 
  for f :: "'a  'b::{banach, second_countable_topology}"
  by(auto simp add: integrable_iff_bounded nn_integral_scale_measure ennreal_mult_less_top)

(* The Lebesgue integral w.r.t. `scale_measure r M` is `enn2real r` times the
   integral w.r.t. M, for integrable f and r under the stated bound.
   Both integrals are unfolded through real_lebesgue_integral_def. *)
lemma integral_scale_measure:
  assumes "integrable M f" "r < "
  shows "integralL (scale_measure r M) f = enn2real r * integralL M f"
  using assms
  apply(subst (1 2) real_lebesgue_integral_def)
    apply(simp_all add: nn_integral_scale_measure ennreal_enn2real_if)
  by(auto simp add: ennreal_mult_less_top ennreal_less_top_iff ennreal_mult_eq_top_iff enn2real_mult right_diff_distrib elim!: integrableE)

subsection ‹Sequence space›

(* Splits a non-negative integral over the sequence space S into an iterated
   integral in which the argument is recombined with `comb_seq i`.
   Uses PiM_comb_seq, nn_integral_distr and P.nn_integral_fst. *)
lemma (in sequence_space) nn_integral_split:
  assumes f[measurable]: "f  borel_measurable S"
  shows "(+ω. f ω S) = (+ω. (+ω'. f (comb_seq i ω ω') S) S)"
by (subst PiM_comb_seq[symmetric, where i=i])
   (simp add: nn_integral_distr P.nn_integral_fst[symmetric])

(* Probability analogue of nn_integral_split: the probability of P over S is
   written as an integral over x of the probability of `P (comb_seq i x x')`.
   Obtained by applying nn_integral_split to the indicator function of the event. *)
lemma (in sequence_space) prob_Collect_split:
  assumes f[measurable]: "{xspace S. P x}  sets S"
  shows "𝒫(x in S. P x) = (+x. 𝒫(x' in S. P (comb_seq i x x')) S)"
proof -
  have "𝒫(x in S. P x) = (+x. (+x'. indicator {xspace S. P x} (comb_seq i x x') S) S)"
    using nn_integral_split[of "indicator {xspace S. P x}"] by (auto simp: emeasure_eq_measure)
  also have " = (+x. 𝒫(x' in S. P (comb_seq i x x')) S)"
    by (intro nn_integral_cong) (auto simp: emeasure_eq_measure nn_integral_indicator_map)
  finally show ?thesis .
qed

subsection ‹Probability mass functions›

(* The measure of a mapped pmf is the distr of the original measure under f into
   `count_space UNIV`; a renaming of map_pmf_rep_eq. *)
lemma measure_map_pmf_conv_distr:
  "measure_pmf (map_pmf f p) = distr (measure_pmf p) (count_space UNIV) f"
by(fact map_pmf_rep_eq)

(* Abbreviation for `pmf_of_set UNIV` at type bool pmf. *)
abbreviation coin_pmf :: "bool pmf" where "coin_pmf  pmf_of_set UNIV"

text ‹The rule @{thm [source] rel_pmf_bindI} is not complete as a program logic.›
notepad begin
  (* Concrete instance: x and y are pmf_of_set over {True, False}, f ignores its
     argument and returns that same pmf, g is return_pmf, and P is equality. *)
  define x where "x = pmf_of_set {True, False}"
  define y where "y = pmf_of_set {True, False}"
  define f where "f x = pmf_of_set {True, False}" for x :: bool
  define g :: "bool  bool pmf" where "g = return_pmf"
  define P :: "bool  bool  bool" where "P = (=)"
  (* The two binds are related by P ... *)
  have "rel_pmf P (bind_pmf x f) (bind_pmf y g)"
    by(simp add: P_def f_def[abs_def] g_def y_def bind_return_pmf' pmf.rel_eq)
  (* ... but no pair can be in a relation R that meets the premise relating
     `f x` and `g y` by rel_pmf P ... *)
  have "¬ R x y" if "x y. R x y  rel_pmf P (f x) (g y)" for R x y
    ― ‹Only the empty relation satisfies @{thm [source] rel_pmf_bindI}'s second premise.›
  proof
    assume "R x y"
    hence "rel_pmf P (f x) (g y)" by(rule that)
    thus False by(auto simp add: P_def f_def g_def rel_pmf_return_pmf2)
  qed
  (* ... and with the empty relation, x and y are not related by rel_pmf. *)
  define R where "R x y = False" for x y :: bool
  have "¬ rel_pmf R x y" by(simp add: R_def[abs_def])
end

(* Transports a pred_pmf fact along rel_pmf: if P holds on p and p is related to
   q by R, then `Imagep R P` holds on q.  The apply script unfolds rel_pmf.simps
   to obtain the joint pmf, picks an element of its support, and uses its two
   projections together with the Imagep introduction rule. *)
lemma pred_rel_pmf:
  " pred_pmf P p; rel_pmf R p q   pred_pmf (Imagep R P) q"
unfolding pred_pmf_def
apply(rule ballI)
apply(unfold rel_pmf.simps)
apply(erule exE conjE)+
apply hypsubst
apply(unfold pmf.set_map)
apply(erule imageE, hypsubst)
apply(drule bspec)
 apply(erule rev_image_eqI)
 apply(rule refl)
apply(erule Imagep.intros)
apply(erule allE)+
 apply(erule mp)
apply(unfold prod.collapse)
apply assumption
done

(* Monotonicity of rel_pmf in its relation argument, stated with the rel_pmf fact
   as first premise; derived from pmf.rel_mono. *)
lemma pmf_rel_mono': " rel_pmf P x y; P  Q   rel_pmf Q x y"
by(drule pmf.rel_mono) (auto)

(* Every pmf is related to itself by rel_pmf over equality. *)
lemma rel_pmf_eqI [simp]: "rel_pmf (=) x x"
by(simp add: pmf.rel_eq)

(* Bind rule for a shared first component p: it suffices to relate `f x` and `g x`
   for the x mentioned in the premise (tied to set_pmf p).  Instance of
   rel_pmf_bindI with the relation `λx y. x = y ∧ ...` restricted to set_pmf p. *)
lemma rel_pmf_bind_reflI:
  "(x. x  set_pmf p  rel_pmf R (f x) (g x))
   rel_pmf R (bind_pmf p f) (bind_pmf p g)"
by(rule rel_pmf_bindI[where R="λx y. x = y  x  set_pmf p"])(auto intro: rel_pmf_reflI)

(* Strong monotonicity of pred_pmf: the implication from P to P' is only needed
   for the elements a covered by the second premise (mentioning set_pmf p). *)
lemma pmf_pred_mono_strong:
  " pred_pmf P p; a.  a  set_pmf p; P a   P' a   pred_pmf P' p"
by(simp add: pred_pmf_def)

(* Introduction rule: from rel_pmf R on x and y plus pred_pmf facts for P on x
   and Q on y, derive rel_pmf for the relation built from R, P and Q. *)
lemma rel_pmf_restrict_relpI [intro?]:
  " rel_pmf R x y; pred_pmf P x; pred_pmf Q y   rel_pmf (R  P  Q) x y"
by(erule pmf.rel_mono_strong)(simp add: pred_pmf_def)

(* Elimination rule, converse to rel_pmf_restrict_relpI: from rel_pmf for the
   relation built from R, P and Q, obtain rel_pmf R and the two pred_pmf facts.
   The pred_pmf facts are derived from the Domainp of the relation and of its
   converse, respectively. *)
lemma rel_pmf_restrict_relpE [elim?]:
  assumes "rel_pmf (R  P  Q) x y"
  obtains "rel_pmf R x y" "pred_pmf P x" "pred_pmf Q y"
proof
  show "rel_pmf R x y" using assms by(auto elim!: pmf.rel_mono_strong)
  have "pred_pmf (Domainp (R  P  Q)) x" using assms by(fold pmf.Domainp_rel) blast
  then show "pred_pmf P x" by(rule pmf_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)
  have "pred_pmf (Domainp (R  P  Q)¯¯) y" using assms
    by(fold pmf.Domainp_rel)(auto simp only: pmf.rel_conversep Domainp_conversep)
  then show "pred_pmf Q y" by(rule pmf_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

(* Combines rel_pmf_restrict_relpI and rel_pmf_restrict_relpE into a single
   characterisation. *)
lemma rel_pmf_restrict_relp_iff:
  "rel_pmf (R  P  Q) x y  rel_pmf R x y  pred_pmf P x  pred_pmf Q y"
by(blast intro: rel_pmf_restrict_relpI elim: rel_pmf_restrict_relpE)

(* Transitivity rule for calculational proofs: rel_pmf facts over R and S chain
   to a rel_pmf fact over the composition `R OO S`. *)
lemma rel_pmf_OO_trans [trans]:
  " rel_pmf R p q; rel_pmf S q r   rel_pmf (R OO S) p r"
unfolding pmf.rel_compp by blast

(* pred_pmf on a mapped pmf equals pred_pmf on the original pmf with the
   predicate combined with f. *)
lemma pmf_pred_map [simp]: "pred_pmf P (map_pmf f p) = pred_pmf (P  f) p"
by(simp add: pred_pmf_def)

(* pred_pmf on a bind equals pred_pmf on p of the predicate obtained by combining
   `pred_pmf P` with f. *)
lemma pred_pmf_bind [simp]: "pred_pmf P (bind_pmf p f) = pred_pmf (pred_pmf P  f) p"
by(simp add: pred_pmf_def)

(* pred_pmf on return_pmf x is just P x. *)
lemma pred_pmf_return [simp]: "pred_pmf P (return_pmf x) = P x"
by(simp add: pred_pmf_def)

(* For a finite set A (with the stated side condition on A and {}), pred_pmf on
   pmf_of_set A is `Ball A P`. *)
lemma pred_pmf_of_set [simp]: " finite A; A  {}   pred_pmf P (pmf_of_set A) = Ball A P"
by(simp add: pred_pmf_def)

(* Multiset analogue of pred_pmf_of_set: pred_pmf on pmf_of_multiset M is
   `Ball (set_mset M) P`, under the stated side condition on M. *)
lemma pred_pmf_of_multiset [simp]: "M  {#}  pred_pmf P (pmf_of_multiset M) = Ball (set_mset M) P"
by(simp add: pred_pmf_def)

(* pred_pmf on the conditional pmf `cond_pmf p A` is rewritten to pred_pmf on p
   for a predicate mentioning membership in A and P, under the stated side
   condition on set_pmf p and A. *)
lemma pred_pmf_cond [simp]:
  "set_pmf p  A  {}  pred_pmf P (cond_pmf p A) = pred_pmf (λx. x  A  P x) p"
by(auto simp add: pred_pmf_def)

(* pred_pmf on a pair_pmf is a nested pred_pmf: over p for the first component x,
   and over q for P combined with `Pair x`. *)
lemma pred_pmf_pair [simp]:
  "pred_pmf P (pair_pmf p q) = pred_pmf (λx. pred_pmf (P  Pair x) q) p"
by(simp add: pred_pmf_def)

(* pred_pmf on a joined pmf equals pred_pmf of `pred_pmf P` on the outer pmf. *)
lemma pred_pmf_join [simp]: "pred_pmf P (join_pmf p) = pred_pmf (pred_pmf P) p"
by(simp add: pred_pmf_def)

(* For 0 < p < 1, pred_pmf on bernoulli_pmf p amounts to P holding everywhere. *)
lemma pred_pmf_bernoulli [simp]: " 0 < p; p < 1   pred_pmf P (bernoulli_pmf p) = All P"
by(simp add: pred_pmf_def)

(* For 0 < p < 1, pred_pmf on geometric_pmf p amounts to P holding everywhere;
   uses set_pmf_geometric. *)
lemma pred_pmf_geometric [simp]: " 0 < p; p < 1   pred_pmf P (geometric_pmf p) = All P"
by(simp add: pred_pmf_def set_pmf_geometric)

(* For a positive rate, pred_pmf on poisson_pmf amounts to P holding everywhere. *)
lemma pred_pmf_poisson [simp]: "0 < rate  pred_pmf P (poisson_pmf rate) = All P"
by(simp add: pred_pmf_def)

(* Two equations moving map_pmf (on the left resp. right argument) into a
   rel_pmf over a relation built from R, P and Q: the mapped function is pushed
   into R and into the corresponding predicate. *)
lemma pmf_rel_map_restrict_relp: 
  shows pmf_rel_map_restrict_relp1: "rel_pmf (R  P  Q) (map_pmf f p) = rel_pmf (R  f  P  f  Q) p"
  and pmf_rel_map_restrict_relp2: "rel_pmf (R  P  Q) p (map_pmf g q) = rel_pmf ((λx. R x  g)  P  Q  g) p q"
by(simp_all add: pmf.rel_map restrict_relp_def fun_eq_iff)

(* pred_pmf of a pointwise combination of P and Q splits into the corresponding
   combination of `pred_pmf P` and `pred_pmf Q`. *)
lemma pred_pmf_conj [simp]: "pred_pmf (λx. P x  Q x) = (λx. pred_pmf P x  pred_pmf Q x)"
by(auto simp add: pred_pmf_def)

(* pred_pmf of the constantly-True predicate is constantly True. *)
lemma pred_pmf_top [simp]:
  "pred_pmf (λ_. True) = (λ_. True)"
by(simp add: pred_pmf_def)

(* Introduction rule for rel_pmf between two pmf_of_set distributions over finite
   sets A and B, from a cardinality condition (`card`) on sets X and on the
   elements of B that R associates with X.  The proof goes through
   rel_pmf_measureI and rewrites measures of pmf_of_set into cardinalities. *)
lemma rel_pmf_of_setI:
  assumes A: "A  {}" "finite A"
  and B: "B  {}" "finite B"
  and card: "X. X  A  card B * card X  card A * card {yB. xX. R x y}"
  shows "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
apply(rule rel_pmf_measureI)
using assms
apply(clarsimp simp add: measure_pmf_of_set card_gt_0_iff field_simps of_nat_mult[symmetric] simp del: of_nat_mult)
apply(subst mult.commute)
apply(erule meta_allE)
apply(erule meta_impE)
 prefer 2
 apply(erule order_trans)
apply(auto simp add: card_gt_0_iff intro: card_mono)
done

(* rel_witness_pmf A (p, q) picks a joint distribution witnessing rel_pmf A p q.
   It is introduced by specification rather than by an explicit definition, so
   only the three properties below are known about it, and only under the
   premise rel_pmf A (fst xy) (snd xy):
     - its support consists of pairs (a, b) with A a b,
     - its first marginal (map_pmf fst) is fst xy,
     - its second marginal (map_pmf snd) is snd xy.
   The existence proof folds the three goals into one, applies choice, and
   reads the witness off pmf.in_rel. *)
consts rel_witness_pmf :: "('a  'b  bool)  'a pmf × 'b pmf  ('a × 'b) pmf"
specification (rel_witness_pmf)
  set_rel_witness_pmf': "rel_pmf A (fst xy) (snd xy)  set_pmf (rel_witness_pmf A xy)  {(a, b). A a b}"
  map1_rel_witness_pmf': "rel_pmf A (fst xy) (snd xy)  map_pmf fst (rel_witness_pmf A xy) = fst xy"
  map2_rel_witness_pmf': "rel_pmf A (fst xy) (snd xy)  map_pmf snd (rel_witness_pmf A xy) = snd xy"
  apply(fold all_conj_distrib imp_conjR)
  apply(rule choice allI)+
  apply(unfold pmf.in_rel)
  by blast

(* Versions of the three specification facts instantiated with an explicit pair
   (x, y) and simplified, so that fst/snd no longer appear.
   rel_witness_pmf bundles all three under one name. *)
lemmas set_rel_witness_pmf = set_rel_witness_pmf'[of _ "(x, y)" for x y, simplified]
lemmas map1_rel_witness_pmf = map1_rel_witness_pmf'[of _ "(x, y)" for x y, simplified]
lemmas map2_rel_witness_pmf = map2_rel_witness_pmf'[of _ "(x, y)" for x y, simplified]
lemmas rel_witness_pmf = set_rel_witness_pmf map1_rel_witness_pmf map2_rel_witness_pmf

(* Relates p to the witness of rel_pmf A p q: an element a of p is paired with a
   witness element (a', b) that has the same first component and satisfies A a' b.
   Derived from the first-marginal fact (used symmetrically) by rewriting the
   equality into rel_pmf (=) and strengthening with the support fact. *)
lemma rel_witness_pmf1:
  assumes "rel_pmf A p q" 
  shows "rel_pmf (λa (a', b). a = a'  A a' b) p (rel_witness_pmf A (p, q))"
  using map1_rel_witness_pmf[OF assms, symmetric]
  unfolding pmf.rel_eq[symmetric] pmf.rel_map
  by(rule pmf.rel_mono_strong)(auto dest: set_rel_witness_pmf[OF assms, THEN subsetD])

(* Mirror image of rel_witness_pmf1 for the second component: a witness element
   (a, b') is related to the element b of q with b = b' and A a b'.
   Derived from the second-marginal fact in the same way. *)
lemma rel_witness_pmf2:
  assumes "rel_pmf A p q" 
  shows "rel_pmf (λ(a, b') b. b = b'  A a b') (rel_witness_pmf A (p, q)) q"
  using map2_rel_witness_pmf[OF assms]
  unfolding pmf.rel_eq[symmetric] pmf.rel_map
  by(rule pmf.rel_mono_strong)(auto dest: set_rel_witness_pmf[OF assms, THEN subsetD])

(* Conditioning the uniform distribution on a finite set A on B gives the
   uniform distribution on the combination of A and B named in the statement,
   provided that combination is non-empty (assumption nonempty).
   Proved pointwise on the mass function (pmf_eqI) using pmf_cond. *)
lemma cond_pmf_of_set:
  assumes fin: "finite A" and nonempty: "A  B  {}"
  shows "cond_pmf (pmf_of_set A) B = pmf_of_set (A  B)" (is "?lhs = ?rhs")
proof(rule pmf_eqI)
  (* non-emptiness of A itself follows from the assumption and is needed by simp *)
  from nonempty have A: "A  {}" by auto
  show "pmf ?lhs x = pmf ?rhs x" for x
    by(subst pmf_cond; clarsimp simp add: fin A nonempty measure_pmf_of_set split: split_indicator)
qed

(* The product of the uniform distributions on finite, non-empty sets A and B is
   the uniform distribution on the Cartesian product A × B.
   Proved pointwise on the mass function via pmf_pair. *)
lemma pair_pmf_of_set:
  assumes A: "finite A" "A  {}"
    and B: "finite B" "B  {}"
  shows "pair_pmf (pmf_of_set A) (pmf_of_set B) = pmf_of_set (A × B)"
  by(rule pmf_eqI)(clarsimp simp add: pmf_pair assms split: split_indicator)

(* emeasure of a set B under the conditioned pmf q = cond_pmf p A, expressed as a
   quotient of two emeasures under p (numerator: a set built from A and B,
   denominator: A).  Requires the stated side condition on set_pmf p and A.
   Proved by transferring to the measure level with the transfer rule of cond_pmf. *)
lemma emeasure_cond_pmf:
  fixes p A
  defines "q  cond_pmf p A"
  assumes "set_pmf p  A  {}"
  shows "emeasure (measure_pmf q) B = emeasure (measure_pmf p) (A  B) / emeasure (measure_pmf p) A"
proof -
  (* transfer rule for cond_pmf, specialised to the side condition and folded to q *)
  note [transfer_rule] = cond_pmf.transfer[OF assms(2), folded q_def]
  interpret pmf_as_measure .
  show ?thesis by transfer simp
qed

(* Real-valued (measure) counterpart of emeasure_cond_pmf, under the same side
   condition.  Obtained from emeasure_cond_pmf by converting emeasure to measure
   (emeasure_eq_measure, divide_ennreal); positivity of the denominator comes
   from measure_pmf_posI. *)
lemma measure_cond_pmf:
  "measure (measure_pmf (cond_pmf p A)) B = measure (measure_pmf p) (A  B) / measure (measure_pmf p) A"
  if "set_pmf p  A  {}"
  using emeasure_cond_pmf[OF that, of B] that 
  by(auto simp add: measure_pmf.emeasure_eq_measure measure_pmf_posI divide_ennreal)

(* Characterises when a set s has emeasure 0 under a pmf, in terms of set_pmf p
   and s.  The proof goes through an almost-everywhere statement
   (AE_iff_measurable) and then AE_measure_pmf_iff. *)
lemma emeasure_measure_pmf_zero_iff: "emeasure (measure_pmf p) s = 0  set_pmf p  s = {}" (is "?lhs = ?rhs")
proof -
  have "?lhs  (AE x in measure_pmf p. x  s)"
    by(subst AE_iff_measurable)(auto)
  also have " = ?rhs" by(auto simp add: AE_measure_pmf_iff)
  finally show ?thesis .
qed

subsection ‹Subprobability mass functions›

(* Characterises ord_spmf R with return_spmf x on the left-hand side in terms of
   losslessness of p and R x y for the y in set_spmf p. *)
lemma ord_spmf_return_spmf1: "ord_spmf R (return_spmf x) p  lossless_spmf p  (yset_spmf p. R x y)"
by(auto simp add: rel_pmf_return_pmf1 ord_option.simps in_set_spmf lossless_iff_set_pmf_None Ball_def) (metis option.exhaust)

(* ord_spmf R factors as the relational composition (OO) of rel_spmf R followed
   by ord_spmf (=).  Proved by pulling the composition through rel_pmf
   (pmf.rel_compp) and a case analysis on the option-level relations. *)
lemma ord_spmf_conv:
  "ord_spmf R = rel_spmf R OO ord_spmf (=)"
apply(subst pmf.rel_compp[symmetric])
apply(rule arg_cong[where f="rel_pmf"])  
apply(rule ext)+
apply(auto elim!: ord_option.cases option.rel_cases intro: option.rel_intros)
done

(* Same equation as ord_spmf_conv, guarded by NO_MATCH (=) R so that it does not
   apply when R is already (=) (where the rewrite would not make progress). *)
lemma ord_spmf_expand:
  "NO_MATCH (=) R  ord_spmf R = rel_spmf R OO ord_spmf (=)"
by(rule ord_spmf_conv)

(* Destruction rule: ord_spmf (=) p q yields a comparison between the measures of
   an arbitrary set A under p and under q.  Real-valued version of
   ord_spmf_eqD_measure_spmf, unpacked with le_measure. *)
lemma ord_spmf_eqD_measure: "ord_spmf (=) p q  measure (measure_spmf p) A  measure (measure_spmf q) A"
by(drule ord_spmf_eqD_measure_spmf)(simp add: le_measure measure_spmf.emeasure_eq_measure)

(* Measure consequence of ord_spmf R p q: compares the measure of A under p with
   the measure under q of the set of y that are R-related to elements of A.
   The proof splits ord_spmf R into rel_spmf R followed by ord_spmf (=)
   (ord_spmf_expand) with an intermediate p', and chains rel_spmf_measureD with
   ord_spmf_eqD_measure. *)
lemma ord_spmf_measureD:
  assumes "ord_spmf R p q"
  shows "measure (measure_spmf p) A  measure (measure_spmf q) {y. xA. R x y}"
    (is "?lhs  ?rhs")
proof -
  from assms obtain p' where *: "rel_spmf R p p'" and **: "ord_spmf (=) p' q"
    by(auto simp add: ord_spmf_expand)
  have "?lhs  measure (measure_spmf p') {y. xA. R x y}" using * by(rule rel_spmf_measureD)
  also have "  ?rhs" using ** by(rule ord_spmf_eqD_measure)
  finally show ?thesis .
qed

(* Introduction rule for ord_spmf with a bind_pmf on the left: it suffices that
   every f x, for x in the support of p, is below q.
   Proof trick: q is rewritten as a bind over a unit-valued return_pmf so that
   rel_pmf_bindI applies to both sides. *)
lemma ord_spmf_bind_pmfI1:
  "(x. x  set_pmf p  ord_spmf R (f x) q)  ord_spmf R (bind_pmf p f) q"
  apply(rewrite at "ord_spmf _ _ " bind_return_pmf[symmetric, where f="λ_ :: unit. q"])
  apply(rule rel_pmf_bindI[where R="λx y. x  set_pmf p"])
  apply(simp_all add: rel_pmf_return_pmf2)
  done
  
(* spmf version of ord_spmf_bind_pmfI1: bind_spmf is unfolded to bind_pmf and the
   pmf-level rule is applied; the remaining goals are closed by a case split on
   the option value. *)
lemma ord_spmf_bind_spmfI1:
  "(x. x  set_spmf p  ord_spmf R (f x) q)  ord_spmf R (bind_spmf p f) q"
unfolding bind_spmf_def by(rule ord_spmf_bind_pmfI1)(auto split: option.split simp add: in_set_spmf)

(* spmf_of_set on the empty set is the spmf that always yields None. *)
lemma spmf_of_set_empty: "spmf_of_set {} = return_pmf None"
by(simp add: spmf_of_set_def)

(* spmf counterpart of rel_pmf_of_setI for spmf_of_set A and spmf_of_set B.
   Assumption card is the same counting condition as in rel_pmf_of_setI.
   Assumption eq relates the property "finite and non-empty" of A with that of B
   (NOTE(review): the connective between the two sides is not visible here;
   presumably an equivalence - confirm against the theory source). *)
lemma rel_spmf_of_setI:
  assumes card: "X. X  A  card B * card X  card A * card {yB. xX. R x y}"
  and eq: "(finite A  A  {})  (finite B  B  {})"
  shows "rel_spmf R (spmf_of_set A) (spmf_of_set B)"
using eq by(clarsimp simp add: spmf_of_set_def card rel_pmf_of_setI simp del: spmf_of_pmf_pmf_of_set cong: conj_cong)

(* Alternative name for the existing fact map_spmf_bind_spmf. *)
lemmas map_bind_spmf = map_spmf_bind_spmf

(* Rewrites an nn_integral over measure_spmf p as an nn_integral over the
   underlying measure_pmf p restricted to range Some, with the integrand
   precomposed with "the".  Follows from measure_spmf_def and nn_integral_distr. *)
lemma nn_integral_measure_spmf_conv_measure_pmf:
  assumes [measurable]: "f  borel_measurable (count_space UNIV)"
  shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f  the)"
by(simp add: measure_spmf_def nn_integral_distr o_def)

(* A statement about the integral of spmf p over count_space UNIV, derived from
   nn_integral_measure_spmf with the constant function 1.
   NOTE(review): going by the lemma name this states that the integral is not
   infinite; the right-hand side is not visible in this listing. *)
lemma nn_integral_spmf_neq_infinity: "(+ x. spmf p x count_space UNIV)  "
using nn_integral_measure_spmf[where f="λ_. 1", of p, symmetric] by simp

(* Option.bind inside return_pmf can be expressed as bind_spmf on point
   distributions.  Proved by cases on the option value x. *)
lemma return_pmf_bind_option:
  "return_pmf (Option.bind x f) = bind_spmf (return_pmf x) (return_pmf  f)"
by(cases x) simp_all

(* Relates the composition of rel_spmf A and rel_spmf B to rel_spmf of the
   composed relation A OO B; immediate from the compp laws for option and pmf. *)
lemma rel_spmf_pos_distr: "rel_spmf A OO rel_spmf B  rel_spmf (A OO B)"
unfolding option.rel_compp pmf.rel_compp ..

(* Transitivity-style rule (declared [trans] for calculational proofs):
   chaining rel_spmf R and rel_spmf S gives rel_spmf of the composition R OO S. *)
lemma rel_spmf_OO_trans [trans]:
  " rel_spmf R p q; rel_spmf S q r   rel_spmf (R OO S) p r"
by(rule rel_spmf_pos_distr[THEN predicate2D]) auto

(* Equality of two mapped spmfs is expressed via rel_spmf with the relation
   "f x = g y" on the unmapped spmfs. *)
lemma map_spmf_eq_map_spmf_iff: "map_spmf f p = map_spmf g q  rel_spmf (λx y. f x = g y) p q"
by(simp add: spmf_rel_eq[symmetric] spmf_rel_map)

(* Introduction-rule form of map_spmf_eq_map_spmf_iff. *)
lemma map_spmf_eq_map_spmfI: "rel_spmf (λx y. f x = g y) p q  map_spmf f p = map_spmf g q"
by(simp add: map_spmf_eq_map_spmf_iff)

(* Strong monotonicity of rel_spmf: the relation A may be replaced by B provided
   A x y entails B x y for x in set_spmf f and y in set_spmf g.
   Obtained by composing the strong monotonicity rules of pmf and option. *)
lemma spmf_rel_mono_strong:
  "rel_spmf A f g; x y.  x  set_spmf f; y  set_spmf g; A x y   B x y   rel_spmf B f g"
apply(erule pmf.rel_mono_strong)
apply(erule option.rel_mono_strong)
by(clarsimp simp add: in_set_spmf)

(* An spmf has empty support exactly when it is return_pmf None. *)
lemma set_spmf_eq_empty: "set_spmf p = {}  p = return_pmf None"
by auto (metis restrict_spmf_empty restrict_spmf_trivial)


(* The measure of a rectangle A × B under pair_spmf p q is the product of the
   measure of A under p and the measure of B under q.
   The proof works on emeasure: it writes the left-hand side as an integral of
   the mass function times an indicator over count_space, splits it into an
   iterated integral (nn_integral_fst_count_space), pulls the constant factors
   out (nn_integral_cmult / nn_integral_multc), and finally converts back to
   measure. *)
lemma measure_pair_spmf_times:
  "measure (measure_spmf (pair_spmf p q)) (A × B) = measure (measure_spmf p) A * measure (measure_spmf q) B"
proof -
  have "emeasure (measure_spmf (pair_spmf p q)) (A × B) = (+ x. ennreal (spmf (pair_spmf p q) x) * indicator (A × B) x count_space UNIV)"
    by(simp add: nn_integral_spmf[symmetric] nn_integral_count_space_indicator)
  also have " = (+ x. (+ y. (ennreal (spmf p x) * indicator A x) * (ennreal (spmf q y) * indicator B y) count_space UNIV) count_space UNIV)"
    by(subst nn_integral_fst_count_space[symmetric])(auto intro!: nn_integral_cong split: split_indicator simp add: ennreal_mult)
  also have " = (+ x. ennreal (spmf p x) * indicator A x * emeasure (measure_spmf q) B count_space UNIV)"
    by(simp add: nn_integral_cmult nn_integral_spmf[symmetric] nn_integral_count_space_indicator)
  also have " = emeasure (measure_spmf p) A * emeasure (measure_spmf q) B"
    by(simp add: nn_integral_multc)(simp add: nn_integral_spmf[symmetric] nn_integral_count_space_indicator)
  finally show ?thesis by(simp add: measure_spmf.emeasure_eq_measure ennreal_mult[symmetric])
qed

(* Destruction rule: a lossless spmf has a non-empty support. *)
lemma lossless_spmfD_set_spmf_nonempty: "lossless_spmf p  set_spmf p  {}"
using set_pmf_not_empty[of p] by(auto simp add: set_spmf_def bind_UNION lossless_iff_set_pmf_None)

(* The support of a point distribution on an option value x is set_option x. *)
lemma set_spmf_return_pmf: "set_spmf (return_pmf x) = set_option x"
by(cases x) simp_all

(* Associativity between an inner bind_pmf and an outer bind_spmf. *)
lemma bind_spmf_pmf_assoc: "bind_spmf (bind_pmf p f) g = bind_pmf p (λx. bind_spmf (f x) g)"
by(simp add: bind_spmf_def bind_assoc_pmf)

(* For a finite, non-empty A, binding over spmf_of_set A is the same as binding
   over pmf_of_set A. *)
lemma bind_spmf_of_set:  " finite A; A  {}   bind_spmf (spmf_of_set A) f = bind_pmf (pmf_of_set A) f"
by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)

(* bind_spmf over a map_pmf: the map is turned into a bind_pmf whose body binds
   the point distribution on f x. *)
lemma bind_spmf_map_pmf:
  "bind_spmf (map_pmf f p) g = bind_pmf p (λx. bind_spmf (return_pmf (f x)) g)"
by(simp add: map_pmf_def bind_spmf_def bind_assoc_pmf)

(* Reflexivity of rel_spmf (=), available to the simplifier. *)
lemma rel_spmf_eqI [simp]: "rel_spmf (=) x x"
by(simp add: option.rel_eq)

(* Support of an spmf obtained by map_pmf with an option-valued function f:
   collects set_option (f x) over the x in set_pmf p. *)
lemma set_spmf_map_pmf: "set_spmf (map_pmf f p) = (xset_pmf p. set_option (f x))" (* Move up *)
by(simp add: set_spmf_def bind_UNION)

(* With return_spmf x on the left, ord_spmf (=) collapses to equality with
   return_spmf x.  One direction is trivial (first "have"); the other is found
   by metis using antisymmetry of the spmf order. *)
lemma ord_spmf_return_spmf [simp]: "ord_spmf (=) (return_spmf x) p  p = return_spmf x"
proof -
  have "p = return_spmf x  ord_spmf (=) (return_spmf x) p" by simp
  thus ?thesis
    by (metis (no_types) ord_option_eq_simps(2) rel_pmf_return_pmf1 rel_pmf_return_pmf2 spmf.leq_antisym)
qed

(* From here on, these two support equations are default simp rules. *)
declare
  set_bind_spmf [simp]
  set_spmf_return_pmf [simp]

(* A bind_spmf over p and a bind_pmf over q commute when the body depends on both
   bound values.  Reduced to commutativity of bind_pmf (bind_commute_pmf) after
   unfolding bind_spmf. *)
lemma bind_spmf_pmf_commute:
  "bind_spmf p (λx. bind_pmf q (f x)) = bind_pmf q (λy. bind_spmf p (λx. f x y))"
unfolding bind_spmf_def 
by(subst bind_commute_pmf)(auto intro: bind_pmf_cong[OF refl] split: option.split)

(* map_option inside return_pmf expressed as a bind_spmf on point distributions. *)
lemma return_pmf_map_option_conv_bind:
  "return_pmf (map_option f x) = bind_spmf (return_pmf x) (return_spmf  f)"
by(cases x) simp_all

(* Losslessness of a point distribution on an option value x is decided by
   whether x is None. *)
lemma lossless_return_pmf_iff [simp]: "lossless_spmf (return_pmf x)  x  None"
by(cases x) simp_all

(* Losslessness of map_pmf f p for option-valued f, in terms of f x versus None
   on the support of p. *)
lemma lossless_map_pmf: "lossless_spmf (map_pmf f p)  (x  set_pmf p. f x  None)"
using image_iff by(fastforce simp add: lossless_iff_set_pmf_None)

(* Associativity between an inner bind_spmf and an outer bind_pmf.  It needs the
   continuation g to send None to return_pmf None, i.e. to propagate failure. *)
lemma bind_pmf_spmf_assoc:
  "g None = return_pmf None
   bind_pmf (bind_spmf p f) g = bind_spmf p (λx. bind_pmf (f x) g)"
by(auto simp add: bind_spmf_def bind_assoc_pmf bind_return_pmf fun_eq_iff intro!: arg_cong2[where f=bind_pmf] split: option.split)

(* Predicate lifting for spmf, as an abbreviation: the pmf-level lifting of the
   option-level lifting of P. *)
abbreviation pred_spmf :: "('a  bool)  'a spmf  bool"
where "pred_spmf P  pred_pmf (pred_option P)"

(* Characterisation of pred_spmf through the support: P x for the x in set_spmf p.
   Named like a definition so it can be used with simp add / unfolding. *)
lemma pred_spmf_def: "pred_spmf P p  (xset_spmf p. P x)"
by(auto simp add: pred_pmf_def pred_option_def set_spmf_def)

(* Strong monotonicity of pred_spmf: P may be replaced by P' if P a entails P' a
   for the a in set_spmf p. *)
lemma spmf_pred_mono_strong:
  " pred_spmf P p; a.  a  set_spmf p; P a   P' a   pred_spmf P' p"
by(simp add: pred_spmf_def)

(* The domain of rel_spmf R is pred_spmf applied to the domain of R. *)
lemma spmf_Domainp_rel: "Domainp (rel_spmf R) = pred_spmf (Domainp R)"
by(simp add: pmf.Domainp_rel option.Domainp_rel)

(* Introduction rule: rel_spmf R together with pred_spmf P on the left argument
   and pred_spmf Q on the right one gives rel_spmf of the relation built from
   R, P and Q in the conclusion. *)
lemma rel_spmf_restrict_relpI [intro?]:
  " rel_spmf R p q; pred_spmf P p; pred_spmf Q q   rel_spmf (R  P  Q) p q"
by(erule spmf_rel_mono_strong)(simp add: pred_spmf_def)

(* Elimination rule, converse of rel_spmf_restrict_relpI: from rel_spmf of the
   relation built from R, P and Q one obtains rel_spmf R, pred_spmf P on the left
   argument and pred_spmf Q on the right argument.
   The two predicate facts are extracted from the domain of the relation and of
   its converse (spmf_Domainp_rel, restrict_relp_DomainpD). *)
lemma rel_spmf_restrict_relpE [elim?]:
  assumes "rel_spmf (R  P  Q) x y"
  obtains "rel_spmf R x y" "pred_spmf P x" "pred_spmf Q y"
proof
  show "rel_spmf R x y" using assms by(auto elim!: spmf_rel_mono_strong)
  have "pred_spmf (Domainp (R  P  Q)) x" using assms by(fold spmf_Domainp_rel) blast
  then show "pred_spmf P x" by(rule spmf_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)
  have "pred_spmf (Domainp (R  P  Q)¯¯) y" using assms
    by(fold spmf_Domainp_rel)(auto simp only: spmf_rel_conversep Domainp_conversep)
  then show "pred_spmf Q y" by(rule spmf_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

(* Combines rel_spmf_restrict_relpI and rel_spmf_restrict_relpE into a single
   statement relating both sides. *)
lemma rel_spmf_restrict_relp_iff:
  "rel_spmf (R  P  Q) x y  rel_spmf R x y  pred_spmf P x  pred_spmf Q y"
by(blast intro: rel_spmf_restrict_relpI elim: rel_spmf_restrict_relpE)

(* pred_spmf on the usual spmf constructors (spmf analogues of the pred_pmf rules). *)

(* Mapping: f is pushed into the predicate; follows from the existing simp rules. *)
lemma spmf_pred_map: "pred_spmf P (map_spmf f p) = pred_spmf (P  f) p"
by(simp)

(* Bind: P must hold throughout f x for the x ranged over by pred_spmf on p. *)
lemma pred_spmf_bind [simp]: "pred_spmf P (bind_spmf p f) = pred_spmf (pred_spmf P  f) p"
by(simp add: pred_spmf_def bind_UNION)

(* Successful point distribution: only x has to satisfy P. *)
lemma pred_spmf_return: "pred_spmf P (return_spmf x) = P x"
by simp

(* The failing spmf satisfies pred_spmf P for every P. *)
lemma pred_spmf_return_pmf_None: "pred_spmf P (return_pmf None)"
by simp

(* Embedding of a pmf: pred_spmf reduces to pred_pmf. *)
lemma pred_spmf_spmf_of_pmf [simp]: "pred_spmf P (spmf_of_pmf p) = pred_pmf P p"
unfolding pred_spmf_def by(simp add: pred_pmf_def)

(* spmf_of_set A: expressed with finiteness of A and Ball A P. *)
lemma pred_spmf_of_set [simp]: "pred_spmf P (spmf_of_set A) = (finite A  Ball A P)"
by(auto simp add: pred_spmf_def set_spmf_of_set)

(* assert_spmf b: expressed with b and P applied to the unit value. *)
lemma pred_spmf_assert_spmf [simp]: "pred_spmf P (assert_spmf b) = (b  P ())"
by(cases b) simp_all

(* Product: nested pred_spmf, first over p and then over q, applied to Pair x. *)
lemma pred_spmf_pair [simp]:
  "pred_spmf P (pair_spmf p q) = pred_spmf (λx. pred_spmf (P  Pair x) q) p"
by(simp add: pred_spmf_def)

(* Support of try_spmf p q: built from the support of p and, only when p is not
   lossless, the support of q (the if-expression yields {} for lossless p). *)
lemma set_spmf_try [simp]:
  "set_spmf (try_spmf p q) = set_spmf p  (if lossless_spmf p then {} else set_spmf q)"
by(auto simp add: try_spmf_def set_spmf_bind_pmf in_set_spmf lossless_iff_set_pmf_None split: option.splits)(metis option.collapse)

(* If every continuation f x is lossless, bind_spmf distributes over
   TRY ... ELSE ...: binding after the TRY equals trying the two bound branches.
   The proof unfolds bind_spmf and try_spmf to the pmf level and compares both
   sides under bind_pmf_cong; bind_return_pmf' is used backwards to bring the
   right-hand side into bind form. *)
lemma try_spmf_bind_out1:
  "(x. lossless_spmf (f x))  bind_spmf (TRY p ELSE q) f = TRY (bind_spmf p f) ELSE (bind_spmf q f)"
  apply(clarsimp simp add: bind_spmf_def try_spmf_def bind_assoc_pmf bind_return_pmf intro!: bind_pmf_cong[OF refl] split: option.split)
  apply(rewrite in " = _" bind_return_pmf'[symmetric])
  apply(rule bind_pmf_cong[OF refl])
  apply(clarsimp split: option.split simp add: lossless_iff_set_pmf_None)
  done

(* pred_spmf on try_spmf p q: P is required on p, and on q in the case where p is
   not lossless. *)
lemma pred_spmf_try [simp]:
  "pred_spmf P (try_spmf p q) = (pred_spmf P p  (¬ lossless_spmf p  pred_spmf P q))"
by(auto simp add: pred_spmf_def)

(* pred_spmf on cond_spmf p A: the obligation P x is restricted to the x in A.
   Unlike pred_pmf_cond this rule carries no side condition. *)
lemma pred_spmf_cond [simp]:
  "pred_spmf P (cond_spmf p A) = pred_spmf (λx. x  A  P x) p"
by(auto simp add: pred_spmf_def)

(* spmf analogue of pmf_rel_map_restrict_relp: moves a map_spmf on the left
   (first equation) or on the right (second equation) out of a rel_spmf whose
   relation is built from R, P and Q, absorbing the mapped function into the
   relation and the matching predicate. *)
lemma spmf_rel_map_restrict_relp: 
  shows spmf_rel_map_restrict_relp1: "rel_spmf (R  P  Q) (map_spmf f p) = rel_spmf (R  f  P  f  Q) p"
  and spmf_rel_map_restrict_relp2: "rel_spmf (R  P  Q) p (map_spmf g q) = rel_spmf ((λx. R x  g)  P  Q  g) p q"
by(simp_all add: spmf_rel_map restrict_relp_def)

(* pred_spmf of a predicate combined pointwise from P and Q splits into
   pred_spmf P and pred_spmf Q; follows from the existing simp rules. *)
lemma pred_spmf_conj: "pred_spmf (λx. P x  Q x) = (λx. pred_spmf P x  pred_spmf Q x)"
by simp

(* Parametricity of spmf_of_pmf, registered as a transfer rule: arguments related
   by rel_pmf A are sent to results related by rel_spmf A.
   Proved by unfolding the definition and running transfer_prover. *)
lemma spmf_of_pmf_parametric [transfer_rule]: 
  includes lifting_syntax shows
  "(rel_pmf A ===> rel_spmf A) spmf_of_pmf spmf_of_pmf"
unfolding spmf_of_pmf_def[abs_def] by transfer_prover

(* monotone_return_pmf: return_pmf is monotone from the flat order on options
   (option_ord) to ord_spmf (=).  The outer name stores the version passed
   through spmf.mono2mono, which is registered as simp and cont_intro rule. *)
lemma mono2mono_return_pmf[THEN spmf.mono2mono, simp, cont_intro]: (* Move to SPMF *)
  shows monotone_return_pmf: "monotone option_ord (ord_spmf (=)) return_pmf"
by(rule monotoneI)(auto simp add: flat_ord_def)

(* mcont_return_pmf: return_pmf is mcont between the flat ccpo on options and the
   spmf ccpo.  Proved with mcont_finite_chains.  The outer name stores the
   version passed through spmf.mcont2mcont, registered as simp and cont_intro. *)
lemma mcont2mcont_return_pmf[THEN spmf.mcont2mcont, simp, cont_intro]:  (* Move to SPMF *)
  shows mcont_return_pmf: "mcont (flat_lub None) option_ord lub_spmf (ord_spmf (=)) return_pmf"
by(rule mcont_finite_chains[OF _ _ flat_interpretation[THEN ccpo] ccpo_spmf]) simp_all

(* The constantly-True predicate lifts to the constantly-True predicate on spmfs. *)
lemma pred_spmf_top: (* Move up *)
  "pred_spmf (λ_. True) = (λ_. True)"
by(simp)

(* Variant of rel_spmf_restrict_relpI with a weaker first premise: R x y only has
   to be provided for x satisfying P and y satisfying Q. *)
lemma rel_spmf_restrict_relpI' [intro?]:
  " rel_spmf (λx y. P x  Q y  R x y) p q; pred_spmf P p; pred_spmf Q q   rel_spmf (R  P  Q) p q"
by(erule spmf_rel_mono_strong)(simp add: pred_spmf_def)

(* Simp version of set_spmf_map_pmf.  The NO_MATCH guard keeps it from firing when
   f has the form map_option g, i.e. when the term is a map_spmf. *)
lemma set_spmf_map_pmf_MATCH [simp]:
  assumes "NO_MATCH (map_option g) f"
  shows "set_spmf (map_pmf f p) = (xset_pmf p. set_option (f x))"
by(rule set_spmf_map_pmf)

(* Stronger bind rule for rel_spmf: when relating the continuations f x and g y
   one may assume, besides A x y, that x is in the support of p and y in the
   support of q.  Obtained from rel_spmf_bindI by strengthening the relation
   with the two support facts. *)
lemma rel_spmf_bindI':
  " rel_spmf A p q; x y.  A x y; x  set_spmf p; y  set_spmf q   rel_spmf B (f x) (g y) 
   rel_spmf B (p  f) (q  g)"
apply(rule rel_spmf_bindI[where R="λx y. A x y  x  set_spmf p  y  set_spmf q"])
 apply(erule spmf_rel_mono_strong; simp)
apply simp
done

(* Witness construction for rel_spmf: take the pmf-level witness for rel_option A
   and map rel_witness_option over it to obtain an spmf of pairs. *)
definition rel_witness_spmf :: "('a  'b  bool)  'a spmf × 'b spmf  ('a × 'b) spmf" where
  "rel_witness_spmf A = map_pmf rel_witness_option  rel_witness_pmf (rel_option A)"

(* spmf analogues of rel_witness_pmf1 / rel_witness_pmf2: under rel_spmf A p q,
   p (resp. q) is related to rel_witness_spmf A (p, q) by "same first (resp.
   second) component and A holds of the pair".  Both follow from the pmf-level
   facts combined with rel_witness_option1 / rel_witness_option2. *)
lemma assumes "rel_spmf A p q"
  shows rel_witness_spmf1: "rel_spmf (λa (a', b). a = a'  A a' b) p (rel_witness_spmf A (p, q))"
    and rel_witness_spmf2: "rel_spmf (λ(a, b') b. b = b'  A a b') (rel_witness_spmf A (p, q)) q"
  by(auto simp add: pmf.rel_map rel_witness_spmf_def intro: pmf.rel_mono_strong[OF rel_witness_pmf1[OF assms]] rel_witness_option1 pmf.rel_mono_strong[OF rel_witness_pmf2[OF assms]] rel_witness_option2)

(* The weight of assert_spmf b is the indicator of b being True. *)
lemma weight_assert_spmf [simp]: "weight_spmf (assert_spmf b) = indicator {True} b"
  by(simp split: split_indicator)

(* enforce_spmf P p applies enforce_option P to every outcome of p.
   By set_enforce_spmf, the resulting support is the part of set_spmf p on which
   P holds. *)
definition enforce_spmf :: "('a  bool)  'a spmf  'a spmf" where
  "enforce_spmf P = map_pmf (enforce_option P)"

(* Parametricity of enforce_spmf, registered as a transfer rule; proved by
   unfolding the definition and running transfer_prover. *)
lemma enforce_spmf_parametric [transfer_rule]: includes lifting_syntax shows
  "((A ===> (=)) ===> rel_spmf A ===> rel_spmf A) enforce_spmf enforce_spmf"
  unfolding enforce_spmf_def by transfer_prover

(* On a successful point distribution, enforce_spmf keeps x if P x holds and
   fails (return_pmf None) otherwise. *)
lemma enforce_return_spmf [simp]:
  "enforce_spmf P (return_spmf x) = (if P x then return_spmf x else return_pmf None)"
  by(simp add: enforce_spmf_def)

(* Enforcing a predicate on the failing spmf leaves it failing. *)
lemma enforce_return_pmf_None [simp]:
  "enforce_spmf P (return_pmf None) = return_pmf None"
  by(simp add: enforce_spmf_def)

(* enforce_spmf commutes with map_spmf f when the predicate is precomposed with f
   on the inner side. *)
lemma enforce_map_spmf:
  "enforce_spmf P (map_spmf f p) = map_spmf f (enforce_spmf (P  f) p)"
  by(simp add: enforce_spmf_def pmf.map_comp o_def enforce_map_option)

(* enforce_spmf on a bind is pushed into the continuation. *)
lemma enforce_bind_spmf [simp]:
  "enforce_spmf P (bind_spmf p f) = bind_spmf p (enforce_spmf P  f)"
  by(auto simp add: enforce_spmf_def bind_spmf_def map_bind_pmf intro!: bind_pmf_cong split: option.split)

(* The support of enforce_spmf P p consists of the elements of set_spmf p that
   satisfy P. *)
lemma set_enforce_spmf [simp]: "set_spmf (enforce_spmf P p) = {a  set_spmf p. P a}"
  by(auto simp add: enforce_spmf_def in_set_spmf)

(* Monadic reading of enforce_spmf: sample a from p, assert P a, and return a. *)
lemma enforce_spmf_alt_def:
  "enforce_spmf P p = bind_spmf p (λa. bind_spmf (assert_spmf (P a)) (λ_ :: unit. return_spmf a))"
  by(auto simp add: enforce_spmf_def assert_spmf_def map_pmf_def bind_spmf_def bind_return_pmf intro!: bind_pmf_cong split: option.split)

(* Binding over an enforced spmf is binding over p with a continuation that runs
   f x when P x holds and fails otherwise. *)
lemma bind_enforce_spmf [simp]:
  "bind_spmf (enforce_spmf P p) f = bind_spmf p (λx. if P x then f x else return_pmf None)"
  by(auto simp add: enforce_spmf_alt_def assert_spmf_def intro!: bind_spmf_cong)

(* The weight of enforce_spmf P p is the weight of p minus the measure, under p,
   of the set where P fails.
   Step 1 writes the weight as the integral of the indicator of {x. P x};
   step 2 turns that into the stated difference via finite_measure_Diff. *)
lemma weight_enforce_spmf:
  "weight_spmf (enforce_spmf P p) = weight_spmf p - measure (measure_spmf p) {x. ¬ P x}" (is "?lhs = ?rhs")
proof -
  have "?lhs = LINT x|measure_spmf p. indicator {x. P x} x"
    by(auto simp add: enforce_spmf_alt_def weight_bind_spmf o_def simp del: Bochner_Integration.integral_indicator intro!: Bochner_Integration.integral_cong split: split_indicator)
  also have " = ?rhs"
    by(subst measure_spmf.finite_measure_Diff[symmetric])(auto simp add: space_measure_spmf intro!: arg_cong2[where f=measure])
  finally show ?thesis .
qed

(* Losslessness of enforce_spmf P p in terms of losslessness of p and the relation
   between set_spmf p and the set {x. P x}. *)
lemma lossless_enforce_spmf [simp]:
  "lossless_spmf (enforce_spmf P p)  lossless_spmf p  set_spmf p  {x. P x}"
  by(auto simp add: enforce_spmf_alt_def)

(* Extreme predicates.  Going by the lemma names and the top_fun_def / bot_fun_def
   unfoldings below, the first and third statements concern the top and bottom
   predicates (the symbols themselves are not visible in this listing). *)

(* Enforcing the top predicate is the identity function. *)
lemma enforce_spmf_top [simp]: "enforce_spmf  = id"
  by(simp add: enforce_spmf_def)

(* Pointwise form: enforcing the constantly-True predicate returns p unchanged. *)
lemma enforce_spmf_K_True [simp]: "enforce_spmf (λ_. True) p = p"
  using enforce_spmf_top[THEN fun_cong, of p] by(simp add: top_fun_def)

(* Enforcing the bottom predicate always yields the failing spmf. *)
lemma enforce_spmf_bot [simp]: "enforce_spmf  = (λ_. return_pmf None)"
  by(simp add: enforce_spmf_def fun_eq_iff)

(* Pointwise form: enforcing the constantly-False predicate yields return_pmf None. *)
lemma enforce_spmf_K_False [simp]: "enforce_spmf (λ_. False) p = return_pmf None"
  using enforce_spmf_bot[THEN fun_cong, of p] by(simp add: bot_fun_def)

(* If P already holds throughout p (pred_spmf P p), enforcing P changes nothing.
   The proof first shows equality with map_pmf id p (map_pmf_id is removed from
   the simpset so that this form survives) and then simplifies. *)
lemma enforce_pred_id_spmf: "enforce_spmf P p = p" if "pred_spmf P p"
proof -
  have "enforce_spmf P p = map_pmf id p" using that
    by(auto simp add: enforce_spmf_def enforce_pred_id_option simp del: map_pmf_id intro!: pmf.map_cong_pred[OF refl] elim!: pmf_pred_mono_strong)
  then show ?thesis by simp
qed

(* Mapping "the" over the spmf embedding of a pmf recovers the pmf. *)
lemma map_the_spmf_of_pmf [simp]: "map_pmf the (spmf_of_pmf p) = p"
  by(simp add: spmf_of_pmf_def pmf.map_comp o_def)

(* Two nested binds over p and q (where q does not depend on the value drawn from
   p) equal a single bind over pair_spmf p q with an uncurried continuation. *)
lemma bind_bind_conv_pair_spmf:
  "bind_spmf p (λx. bind_spmf q (f x)) = bind_spmf (pair_spmf p q) (λ(x, y). f x y)"
  by(simp add: pair_spmf_alt_def)

(* spmf version of cond_pmf_of_set: conditioning spmf_of_set A on B, for finite A,
   gives spmf_of_set of the combination of A and B in the statement.
   No non-emptiness side condition is required here. *)
lemma cond_spmf_spmf_of_set:
  "cond_spmf (spmf_of_set A) B = spmf_of_set (A  B)" if "finite A"
  by(rule spmf_eqI)(auto simp add: spmf_of_set measure_spmf_of_set that split: split_indicator)

(* spmf version of pair_pmf_of_set, stated without side conditions on A and B. *)
lemma pair_spmf_of_set:
  "pair_spmf (spmf_of_set A) (spmf_of_set B) = spmf_of_set (A × B)"
  by(rule spmf_eqI)(clarsimp simp add: spmf_of_set card_cartesian_product split: split_indicator)

(* spmf version of emeasure_cond_pmf, stated without a side condition.
   The proof unfolds cond_spmf_def, splits on its internal case distinction, and
   uses emeasure_cond_pmf in the remaining case. *)
lemma emeasure_cond_spmf:
  "emeasure (measure_spmf (cond_spmf p A)) B = emeasure (measure_spmf p) (A  B) / emeasure (measure_spmf p) A"
  apply(clarsimp simp add: cond_spmf_def emeasure_measure_spmf_conv_measure_pmf emeasure_measure_pmf_zero_iff set_pmf_Int_Some split!: if_split)
   apply blast
  apply(subst (asm) emeasure_cond_pmf)
  by(auto simp add: set_pmf_Int_Some image_Int)

(* Real-valued counterpart of emeasure_cond_spmf, proved along the same lines
   using measure_cond_pmf. *)
lemma measure_cond_spmf:
  "measure (measure_spmf (cond_spmf p A)) B = measure (measure_spmf p) (A  B) / measure (measure_spmf p) A"
  apply(clarsimp simp add: cond_spmf_def measure_measure_spmf_conv_measure_pmf measure_pmf_zero_iff set_pmf_Int_Some split!: if_split)
  apply(subst (asm) measure_cond_pmf)
  by(auto simp add: image_Int set_pmf_Int_Some)


(* Losslessness of cond_spmf p A is characterised by a non-emptiness condition on
   set_spmf p and A. *)
lemma lossless_cond_spmf [simp]: "lossless_spmf (cond_spmf p A)  set_spmf p  A  {}"
  by(clarsimp simp add: cond_spmf_def lossless_iff_set_pmf_None set_pmf_Int_Some)

(* measure_spmf p is the counting measure on UNIV weighted by the mass function
   spmf p (a density measure). *)
lemma measure_spmf_eq_density: "measure_spmf p = density (count_space UNIV) (spmf p)"
  by(rule measure_eqI)(simp_all add: emeasure_density nn_integral_spmf[symmetric] nn_integral_count_space_indicator)

(* Computes an integral over measure_spmf M as a finite sum over A of
   spmf M a scaled with f a.  A must be finite, and the premise ties together
   membership in set_spmf M, f a versus 0, and membership in A.
   NOTE(review): presumably the premise says that every support point with
   non-zero f a lies in A - confirm against the theory source.
   The proof rewrites measure_spmf as a density over count_space and applies
   lebesgue_integral_count_space_finite_support. *)
lemma integral_measure_spmf:
  fixes f :: "'a  'b::{banach, second_countable_topology}"
  assumes A: "finite A"
  shows "(a. a  set_spmf M  f a  0  a  A)  (LINT x|measure_spmf M. f x) = (aA. spmf M a *R f a)"
  unfolding measure_spmf_eq_density
  apply (simp add: integral_density)
  apply (subst lebesgue_integral_count_space_finite_support)
  by (auto intro!: finite_subset[OF _ ‹finite A] sum.mono_neutral_left simp: spmf_eq_0_set_spmf)


(* If two mapped spmfs are equal, the images of their supports coincide.
   The premise is wrapped in ASSUMPTION; the proof unfolds it and applies
   set_spmf to both sides of the equation. *)
lemma image_set_spmf_eq:
  "f ` set_spmf p = g ` set_spmf q" if "ASSUMPTION (map_spmf f p = map_spmf g q)"
  using that[unfolded ASSUMPTION_def, THEN arg_cong[where f=set_spmf]] by simp

(* Mapping a constant function over p gives return_spmf x scaled by the weight of p. *)
lemma map_spmf_const: "map_spmf (λ_. x) p = scale_spmf (weight_spmf p) (return_spmf x)"
  by(simp add: map_spmf_conv_bind_spmf bind_spmf_const)

(* Conditioning a point distribution on a set A leaves it unchanged, under the
   stated premise relating x and A. *)
lemma cond_return_pmf [simp]: "cond_pmf (return_pmf x) A = return_pmf x" if "x  A"
  using that by(intro pmf_eqI)(auto simp add: pmf_cond split: split_indicator)

(* spmf version with an explicit case distinction on x and A: one branch keeps
   return_spmf x, the other fails. *)
lemma cond_return_spmf [simp]: "cond_spmf (return_spmf x) A = (if x  A then return_spmf x else return_pmf None)"
  by(simp add: cond_spmf_def)

(* The pmf-level probability of the set of Some-values is the weight of the spmf. *)
lemma measure_range_Some_eq_weight:
  "measure (measure_pmf p) (range Some) = weight_spmf p"
  by (simp add: measure_measure_spmf_conv_measure_pmf space_measure_spmf)

(* Characterises when restrict_spmf p A is the failing spmf, via an emptiness
   condition on set_spmf p and A. *)
lemma restrict_spmf_eq_return_pmf_None [simp]:
  "restrict_spmf p A = return_pmf None  set_spmf p  A = {}"
  by(auto 4 3 simp add: restrict_spmf_def map_pmf_eq_return_pmf_iff bind_UNION in_set_spmf bind_eq_None_conv option.the_def dest: bspec split: if_split_asm option.split_asm)

(* mk_lossless p rescales p by the inverse of its weight.
   By weight_mk_lossless the result has weight 1 unless p is return_pmf None,
   in which case the weight is 0. *)
definition mk_lossless :: "'a spmf  'a spmf" where
  "mk_lossless p = scale_spmf (inverse (weight_spmf p)) p"

(* Basic simp rules for mk_lossless. *)

(* Idempotence. *)
lemma mk_lossless_idem [simp]: "mk_lossless (mk_lossless p) = mk_lossless p"
  by(simp add: mk_lossless_def weight_scale_spmf min_def max_def inverse_eq_divide) 

(* Point distributions on an option value are left unchanged. *)
lemma mk_lossless_return [simp]: "mk_lossless (return_pmf x) = return_pmf x"
  by(cases x)(simp_all add: mk_lossless_def)

(* mk_lossless commutes with map_spmf. *)
lemma mk_lossless_map [simp]: "mk_lossless (map_spmf f p) = map_spmf f (mk_lossless p)"
  by(simp add: mk_lossless_def map_scale_spmf)

(* Mass function: each mass of p is divided by the weight of p. *)
lemma spmf_mk_lossless [simp]: "spmf (mk_lossless p) x = spmf p x / weight_spmf p"
  by(simp add: mk_lossless_def spmf_scale_spmf inverse_eq_divide max_def)

(* The support is unchanged. *)
lemma set_spmf_mk_lossless [simp]: "set_spmf (mk_lossless p) = set_spmf p"
  by(simp add: mk_lossless_def set_scale_spmf measure_spmf_zero_iff zero_less_measure_iff)

(* A lossless spmf is a fixed point of mk_lossless. *)
lemma mk_lossless_lossless [simp]: "lossless_spmf p  mk_lossless p = p"
  by(simp add: mk_lossless_def lossless_weight_spmfD)

(* mk_lossless p is the failing spmf exactly when p is. *)
lemma mk_lossless_eq_return_pmf_None [simp]: "mk_lossless p = return_pmf None  p = return_pmf None"
proof -
  (* auxiliary fact: zero weight forces every point mass to be zero *)
  have aux: "weight_spmf p = 0  spmf p i = 0" for i
    by(rule antisym, rule order_trans[OF spmf_le_weight]) (auto intro!: order_trans[OF spmf_le_weight])

  (* local simp rule: if the scaled spmf has the mass function of return_pmf None,
     then every point mass of p is zero *)
  have[simp]: " spmf (scale_spmf (inverse (weight_spmf p)) p) = spmf (return_pmf None)  spmf p i = 0" for i
    by(drule fun_cong[where x=i]) (auto simp add: aux spmf_scale_spmf max_def)

  show ?thesis by(auto simp add: mk_lossless_def intro: spmf_eqI)
qed

(* mk_lossless_eq_return_pmf_None with the equation on the left flipped, so that
   the simplifier also catches that orientation. *)
lemma return_pmf_None_eq_mk_lossless [simp]: "return_pmf None = mk_lossless p  p = return_pmf None"
  by(metis mk_lossless_eq_return_pmf_None)

(* spmf_of_set A is left unchanged by mk_lossless. *)
lemma mk_lossless_spmf_of_set [simp]: "mk_lossless (spmf_of_set A) = spmf_of_set A"
  by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)

(* The weight after mk_lossless is 0 for the failing spmf and 1 otherwise. *)
lemma weight_mk_lossless: "weight_spmf (mk_lossless p) = (if p = return_pmf None then 0 else 1)"
  by(simp add: mk_lossless_def weight_scale_spmf min_def max_def inverse_eq_divide weight_spmf_eq_0)

(* Parametricity of mk_lossless, registered as a transfer rule.  Related spmfs
   have equal weight (rel_spmf_weightD), so both sides are scaled by the same
   factor (rel_spmf_scaleI). *)
lemma mk_lossless_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_spmf A ===> rel_spmf A) mk_lossless mk_lossless"
  by(simp add: mk_lossless_def rel_fun_def rel_spmf_weightD rel_spmf_scaleI)

(* The parametricity statement as a plain introduction rule. *)
lemma rel_spmf_mk_losslessI:
  "rel_spmf A p q  rel_spmf A (mk_lossless p) (mk_lossless q)"
  by(rule mk_lossless_parametric[THEN rel_funD])

(* Introduction rule for rel_spmf R between restrict_spmf p A and
   restrict_spmf q B.  The premise is a rel_spmf on the unrestricted p and q
   whose relation combines membership of x in A and of y in B with R x y. *)
lemma rel_spmf_restrict_spmfI:
  "rel_spmf (λx y. (x  A  y  B  R x y)  x  A  y  B) p q
    rel_spmf R (restrict_spmf p A) (restrict_spmf q B)"
  by(auto simp add: restrict_spmf_def pmf.rel_map elim!: option.rel_cases pmf.rel_mono_strong)

(* Alternative description of conditioning: restrict p to A, then renormalise with
   mk_lossless.  Case split on the emptiness condition used in the "cases"
   command: in the True case both sides simplify directly; in the False case the
   mass functions are compared pointwise via pmf_cond. *)
lemma cond_spmf_alt: "cond_spmf p A = mk_lossless (restrict_spmf p A)"
proof(cases "set_spmf p  A = {}")
  case True
  then show ?thesis by(simp add: cond_spmf_def measure_spmf_zero_iff)
next
  case False
  show ?thesis
    by(rule spmf_eqI)(simp add: False cond_spmf_def pmf_cond set_pmf_Int_Some image_iff measure_measure_spmf_conv_measure_pmf[symmetric] spmf_scale_spmf max_def inverse_eq_divide)
qed

(* Conditioning a bind: expressed as mk_lossless of a bind over p whose
   continuation is f x combined with A (the combining operator is not visible in
   this listing; the proof uses restrict_bind_spmf). *)
lemma cond_spmf_bind:
  "cond_spmf (bind_spmf p f) A = mk_lossless (p  (λx. f x  A))"
  by(simp add: cond_spmf_alt restrict_bind_spmf scale_bind_spmf)

(* Conditioning on the universal set is just renormalisation. *)
lemma cond_spmf_UNIV [simp]: "cond_spmf p UNIV = mk_lossless p"
  by(clarsimp simp add: cond_spmf_alt)

(* If the combination of set_pmf p and A in the premise is the singleton {x},
   conditioning p on A yields the point distribution on x.
   The local simp rule states that, in this situation, the probability of A
   equals the mass of x; the result then follows pointwise from pmf_cond. *)
lemma cond_pmf_singleton:
  "cond_pmf p A = return_pmf x" if "set_pmf p  A = {x}"
proof -
  have[simp]: "set_pmf p  A = {x}  x  A  measure_pmf.prob p A = pmf p x"
    by(auto simp add: measure_pmf_single[symmetric] AE_measure_pmf_iff intro!: measure_pmf.finite_measure_eq_AE)

  have "pmf (cond_pmf p A) i = pmf (return_pmf x) i" for i
    using that by(auto simp add: pmf_cond measure_pmf_zero_iff pmf_eq_0_set_pmf split: split_indicator)

  then show ?thesis by(rule pmf_eqI)
qed


(* cond_spmf_fst p a: condition the spmf of pairs p on the first component being a
   (the set {a} × UNIV) and keep only the second component. *)
definition cond_spmf_fst :: "('a × 'b) spmf  'a  'b spmf" where
  "cond_spmf_fst p a = map_spmf snd (cond_spmf p ({a} × UNIV))"

(* Point distribution on (x, y), conditioned on first component x: returns y. *)
lemma cond_spmf_fst_return_spmf [simp]:
  "cond_spmf_fst (return_spmf (x, y)) x = return_spmf y"
  by(simp add: cond_spmf_fst_def)

(* If every outcome is paired with the same x, conditioning on x renormalises p. *)
lemma cond_spmf_fst_map_Pair [simp]: "cond_spmf_fst (map_spmf (Pair x) p) x = mk_lossless p"
  by(clarsimp simp add: cond_spmf_fst_def spmf.map_comp o_def)

(* Same with an additional function f applied to the second component. *)
lemma cond_spmf_fst_map_Pair' [simp]: "cond_spmf_fst (map_spmf (λy. (x, f y)) p) x = map_spmf f (mk_lossless p)"
  by(subst spmf.map_comp[where f="Pair x", symmetric, unfolded o_def]) simp

(* Characterises when cond_spmf_fst p x fails, in terms of x and the first
   components occurring in set_spmf p. *)
lemma cond_spmf_fst_eq_return_None [simp]: "cond_spmf_fst p x = return_pmf None  x  fst ` set_spmf p"
  by(auto 4 4 simp add: cond_spmf_fst_def map_pmf_eq_return_pmf_iff in_set_spmf[symmetric] dest: bspec[where x="Some _"] intro: ccontr rev_image_eqI)

(* For p mapped to pairs (f x, g x) with f injective on the support of p:
   conditioning on the first component f x, for x in the support, returns the
   single value g applied to the f-preimage of f x (inv_into).
   ?foo y abbreviates the pmf-level event "the pair has first component f y".
   The first local simp rule states the side condition on set_pmf p and ?foo y
   required by the conditioning lemmas; the second computes the conditioned
   distribution with cond_pmf_singleton. *)
lemma cond_spmf_fst_map_Pair1:
  "cond_spmf_fst (map_spmf (λx. (f x, g x)) p) (f x) = return_spmf (g (inv_into (set_spmf p) f (f x)))"
  if "x  set_spmf p" "inj_on f (set_spmf p)"
proof -
  let ?foo="λy. map_option (λx. (f x, g x)) -` Some ` ({f y} × UNIV)"
  have[simp]: "y  set_spmf p  f x = f y  set_pmf p  (?foo y)  {}" for y
    by(auto simp add: vimage_def image_def in_set_spmf)

  have[simp]: "y  set_spmf p  f x = f y   map_spmf snd (map_spmf (λx. (f x, g x)) (cond_pmf p (?foo y))) = return_spmf (g x)" for y
    using that by(subst cond_pmf_singleton[where x="Some x"]) (auto simp add: in_set_spmf elim: inj_onD)

  show ?thesis
    using that
    by(auto simp add: cond_spmf_fst_def cond_spmf_def)
      (erule notE, subst cond_map_pmf, simp_all)
qed

(* Losslessness of cond_spmf_fst p x in terms of x and the first components
   occurring in set_spmf p. *)
lemma lossless_cond_spmf_fst [simp]: "lossless_spmf (cond_spmf_fst p x)  x  fst ` set_spmf p"
  by(auto simp add: cond_spmf_fst_def intro: rev_image_eqI)

(* Reconstruction of a joint spmf: sample the first component from its marginal
   (map_spmf fst p), then the second from the conditional cond_spmf_fst p x, and
   pair them; the result is p again.
   Proved pointwise on the mass function at a pair i (spmf_eqI):
     1. write the mass of the left-hand side as an integral over the marginal,
     2. turn the inner mass into a measure of a preimage under cond_spmf,
     3. expand that measure with measure_cond_spmf into a quotient,
     4. evaluate the integral, which by fact * only picks up x = fst i
        (integral_measure_spmf with A = {fst i}),
     5. cancel the marginal mass against the denominator. *)
lemma cond_spmf_fst_inverse:
  "bind_spmf (map_spmf fst p) (λx. map_spmf (Pair x) (cond_spmf_fst p x)) = p"
  (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix i :: "'a × 'b"
  have *: "({x} × UNIV  (Pair x  snd) -` {i}) = (if x = fst i then {i} else {})" for x by(cases i)auto
  have "spmf ?lhs i = LINT x|measure_spmf (map_spmf fst p). spmf (map_spmf (Pair x  snd) (cond_spmf p ({x} × UNIV))) i"
    by(auto simp add: spmf_bind spmf.map_comp[symmetric] cond_spmf_fst_def intro!: integral_cong_AE)
  also have " = LINT x|measure_spmf (map_spmf fst p). measure (measure_spmf (cond_spmf p ({x} × UNIV))) ((Pair x  snd) -` {i})"
    by(rule integral_cong_AE)(auto simp add: spmf_map)
  also have " = LINT x|measure_spmf (map_spmf fst p). measure (measure_spmf p) ({x} × UNIV  (Pair x  snd) -` {i}) /
       measure (measure_spmf p) ({x} × UNIV)"
    by(rule integral_cong_AE; clarsimp simp add: measure_cond_spmf)
  also have " = spmf (map_spmf fst p) (fst i) * spmf p i / measure (measure_spmf p) ({fst i} × UNIV)"
    by(simp add: * if_distrib[where f="measure (measure_spmf _)"] cong: if_cong)
      (subst integral_measure_spmf[where A="{fst i}"]; auto split: if_split_asm simp add: spmf_conv_measure_spmf)
  also have " = spmf p i"
    by(clarsimp simp add: spmf_map vimage_fst)(metis (no_types, lifting) Int_insert_left_if1 in_set_spmf_iff_spmf insertI1 insert_UNIV insert_absorb insert_not_empty measure_spmf_zero_iff mem_Sigma_iff prod.collapse)
  finally show "spmf ?lhs i = spmf ?rhs i" .
qed

subsubsection ‹Embedding of @{typ "'a option"} into @{typ "'a spmf"}›

text ‹In theory, this follows from the embedding of @{typ "_ id"} into @{typ "_ prob"} and the isomorphism
  between @{typ "(_, _ prob) optionT"} and @{typ "_ spmf"}, but that route would only yield the monomorphic
  version. So we establish the embedding directly.
›

(* The generic option-monad bind with failure value return_pmf None, applied to
   x, coincides with bind_spmf on the point distribution of x. *)
lemma bind_option_spmf_monad [simp]: "monad.bind_option (return_pmf None) x = bind_spmf (return_pmf x)"
by(cases x)(simp_all add: fun_eq_iff)

locale option_to_spmf begin

text ‹
  We have to get the embedding into the lifting package such that we can use the parametrisation of transfer rules.
›

(* the_pmf p selects the unique x with p = return_pmf x (definite description);
   it is only meaningful when p is a point distribution. *)
definition the_pmf :: "'a pmf  'a" where "the_pmf p = (THE x. p = return_pmf x)"

(* the_pmf inverts return_pmf. *)
lemma the_pmf_return [simp]: "the_pmf (return_pmf x) = x"
by(simp add: the_pmf_def)

(* return_pmf and the_pmf form a type definition between 'a option and the set of
   point distributions over 'a option; this is the input to setup_lifting below. *)
lemma type_definition_option_spmf: "type_definition return_pmf the_pmf {x. y :: 'a option. x = return_pmf y}"
by unfold_locales(auto)

(* Lifting setup for the option/spmf embedding.
   setup_lifting is run privately inside this context on the type definition
   above.  The facts and constants it generates carry "option" names (cr_option,
   pcr_option, option.bi_unique, ...); they are re-exported here under
   spmf_option names.  The bundle spmf_option_lifting restores the stored lifting
   setup by its internal name so that it can be included where needed. *)
context begin
private setup_lifting type_definition_option_spmf
abbreviation cr_spmf_option where "cr_spmf_option  cr_option"
abbreviation pcr_spmf_option where "pcr_spmf_option  pcr_option"
lemmas Quotient_spmf_option = Quotient_option
  and cr_spmf_option_def = cr_option_def
  and pcr_spmf_option_bi_unique = option.bi_unique
  and Domainp_pcr_spmf_option = option.domain
  and Domainp_pcr_spmf_option_eq = option.domain_eq
  and Domainp_pcr_spmf_option_par = option.domain_par
  and Domainp_pcr_spmf_option_left_total = option.domain_par_left_total
  and pcr_spmf_option_left_unique = option.left_unique
  and pcr_spmf_option_cr_eq = option.pcr_cr_eq
  and pcr_spmf_option_return_pmf_transfer = option.rep_transfer
  and pcr_spmf_option_right_total = option.right_total
  and pcr_spmf_option_right_unique = option.right_unique
  and pcr_spmf_option_def = pcr_option_def
bundle spmf_option_lifting = [[Lifting.lifting_restore_internal "Misc_CryptHOL.option.lifting"]]
end


context includes lifting_syntax begin

(* Transfer rule: return_spmf and Some are related by (=) ===> cr_spmf_option.
   return_spmf_parametric is supplied as the parametricity theorem. *)
lemma return_option_spmf_transfer [transfer_parametric return_spmf_parametric, transfer_rule]:
  "((=) ===> cr_spmf_option) return_spmf Some"
by(rule rel_funI)(simp add: cr_spmf_option_def)

lemma map_option_spmf_transfer [transfer_parametric map_spmf_parametric, transfer_rule]:
  "(((=) ===> (=)) ===> cr_spmf_option ===> cr_spmf_option) map_spmf map_option"
unfolding rel_fun_eq by(auto simp add: rel_fun_def cr_spmf_option_def)

(* Transfer rule for failure: the Dirac distribution return_pmf None is related to
   None by cr_spmf_option. *)
lemma fail_option_spmf_transfer [transfer_parametric return_spmf_None_parametric, transfer_rule]:
  "cr_spmf_option (return_pmf None) None"
by(simp add: cr_spmf_option_def)

(* Transfer rule for monadic bind: bind_spmf is related to Option.bind when the first
   argument and the results of the continuation are related by cr_spmf_option.
   After unfolding the relations, the proof splits on the option value x. *)
lemma bind_option_spmf_transfer [transfer_parametric bind_spmf_parametric, transfer_rule]:
  "(cr_spmf_option ===> ((=) ===> cr_spmf_option) ===> cr_spmf_option) bind_spmf Option.bind"
apply(clarsimp simp add: rel_fun_def cr_spmf_option_def)
subgoal for x f g by(cases x; simp)
done

lemma set_option_spmf_transfer [transfer_parametric set_spmf_parametric, transfer_rule]:
  "(cr_spmf_option ===> rel_set (=)) set_spmf set_option"
by(clarsimp simp add: rel_fun_def cr_spmf_option_def rel_set_eq)

lemma rel_option_spmf_transfer [transfer_parametric rel_spmf_parametric, transfer_rule]:
  "(((=) ===> (=) ===> (=)) ===> cr_spmf_option ===> cr_spmf_option ===> (=)) rel_spmf rel_option"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_spmf_option_def)

end

end

locale option_le_spmf begin

text ‹
  Embedding where only successful computations in the option monad are related to Dirac spmf.
›

(* Correspondence relation between an option value x and an spmf p: it holds when the
   Dirac distribution return_pmf x is related to p by ord_spmf (=). *)
definition cr_option_le_spmf :: "'a option  'a spmf  bool"
where "cr_option_le_spmf x p  ord_spmf (=) (return_pmf x) p"

context includes lifting_syntax begin

lemma return_option_le_spmf_transfer [transfer_rule]:
  "((=) ===> cr_option_le_spmf) (λx. x) return_pmf"
by(rule rel_funI)(simp add: cr_option_le_spmf_def ord_option_reflI)

lemma map_option_le_spmf_transfer [transfer_rule]:
  "(((=) ===> (=)) ===> cr_option_le_spmf ===> cr_option_le_spmf) map_option map_spmf"
unfolding rel_fun_eq
apply(clarsimp simp add: rel_fun_def cr_option_le_spmf_def rel_pmf_return_pmf1 ord_option_map1 ord_option_map2)
subgoal for f x p y by(cases x; simp add: ord_option_reflI)
done

(* Transfer rule: Option.bind is related to bind_spmf under cr_option_le_spmf, for
   continuations that are pointwise related by cr_option_le_spmf.
   The proof unfolds the relation and case-splits on the option value x. *)
lemma bind_option_le_spmf_transfer [transfer_rule]:
  "(cr_option_le_spmf ===> ((=) ===> cr_option_le_spmf) ===> cr_option_le_spmf) Option.bind bind_spmf"
apply(clarsimp simp add: rel_fun_def cr_option_le_spmf_def)
subgoal for x p f g by(cases x; auto 4 3 simp add: rel_pmf_return_pmf1 set_pmf_bind_spmf)
done

end

end

(* Global interpretation of the locale rel_spmf_characterisation; its assumption is
   discharged by rel_pmf_measureI. *)
interpretation rel_spmf_characterisation by unfold_locales(rule rel_pmf_measureI)

(* Distributes bind_spmf over an if-then-else in its first argument.
   Collected in the named theorem list if_distribs. *)
lemma if_distrib_bind_spmf1 [if_distribs]:
  "bind_spmf (if b then x else y) f = (if b then bind_spmf x f else bind_spmf y f)"
by simp

lemma if_distrib_bind_spmf2 [if_distribs]:
  "bind_spmf x (λy. if b then f y else g y) = (if b then bind_spmf x f else bind_spmf x g)"
by simp

lemma rel_spmf_if_distrib [if_distribs]:
  "rel_spmf R (if b then x else y) (if b then x' else y') 
  (b  rel_spmf R x x')  (¬ b  rel_spmf R y y')"
by(simp)

lemma if_distrib_map_spmf [if_distribs]:
  "map_spmf f (if b then p else q) = (if b then map_spmf f p else map_spmf f q)"
by simp

lemma if_distrib_restrict_spmf1 [if_distribs]:
  "restrict_spmf (if b then p else q) A = (if b then restrict_spmf p A else restrict_spmf q A)"
by simp

end

Theory Set_Applicative

theory Set_Applicative imports
  Applicative_Lifting.Applicative_Set
begin

subsection ‹Applicative instance for @{typ "'a set"}›

(* Expresses ap_set through Set.bind: bind over the set of functions, then over the
   set of arguments, returning the singleton {f x}. *)
lemma ap_set_conv_bind: "ap_set f x = Set.bind f (λf. Set.bind x (λx. {f x}))"
by(auto simp add: ap_set_def bind_UNION)

context includes applicative_syntax begin

lemma in_ap_setI: " f'  f; x'  x   f' x'  f  x"
by(auto simp add: ap_set_def)

lemma in_ap_setE [elim!]:
  " x  f  y; f' y'.  x = f' y'; f'  f; y'  y   thesis   thesis"
by(auto simp add: ap_set_def)

lemma in_ap_pure_set [iff]: "x  {f}  y  (y'y. x = f y')"
unfolding ap_set_def by auto

end

end

Theory SPMF_Applicative

theory SPMF_Applicative imports
  Applicative_Lifting.Applicative_PMF
  Set_Applicative
  "HOL-Probability.SPMF"
begin

declare eq_on_def [simp del]

subsection ‹Applicative instance for @{typ "'a spmf"}›

(* Input-only abbreviation: pure_spmf is another name for return_spmf, used as the
   pure operation of the applicative instance declared later in this theory. *)
abbreviation (input) pure_spmf :: "'a  'a spmf"
where "pure_spmf  return_spmf"

(* Applicative application for spmfs: pair the two spmfs with pair_spmf and map
   function application (λ(f, x). f x) over the resulting pairs. *)
definition ap_spmf :: "('a  'b) spmf  'a spmf  'b spmf"
where "ap_spmf f x = map_spmf (λ(f, x). f x) (pair_spmf f x)"

(* Monadic characterisation of ap_spmf: bind the function, then the argument, and
   return the application f x. *)
lemma ap_spmf_conv_bind: "ap_spmf f x = bind_spmf f (λf. bind_spmf x (λx. return_spmf (f x)))"
by(simp add: ap_spmf_def map_spmf_conv_bind_spmf pair_spmf_alt_def)

adhoc_overloading Applicative.ap ap_spmf

context includes applicative_syntax begin

lemma ap_spmf_id: "pure_spmf (λx. x)  x = x"
by(simp add: ap_spmf_def pair_spmf_return_spmf1 spmf.map_comp o_def)

lemma ap_spmf_comp: "pure_spmf (∘)  u  v  w = u  (v  w)"
by(simp add: ap_spmf_def pair_spmf_return_spmf1 pair_map_spmf1 pair_map_spmf2 spmf.map_comp o_def split_def pair_pair_spmf)

lemma ap_spmf_homo: "pure_spmf f  pure_spmf x = pure_spmf (f x)"
by(simp add: ap_spmf_def pair_spmf_return_spmf1)

lemma ap_spmf_interchange: "u  pure_spmf x = pure_spmf (λf. f x)  u"
by(simp add: ap_spmf_def pair_spmf_return_spmf1 pair_spmf_return_spmf2 spmf.map_comp o_def)

(* C-combinator law for ap_spmf: applying the argument-flipping function
   (λf x y. f y x) lets the two arguments x and y be swapped.
   The proof normalises both sides to maps over nested pair_spmf and uses
   pair_commute_spmf to swap one pair. *)
lemma ap_spmf_C: "return_spmf (λf x y. f y x)  f  x  y = f  y  x"
apply(simp add: ap_spmf_def pair_map_spmf1 spmf.map_comp pair_spmf_return_spmf1 pair_pair_spmf o_def split_def)
apply(subst (2) pair_commute_spmf)
apply(simp add: pair_map_spmf2 spmf.map_comp o_def split_def)
done

(* Registers spmf as an applicative functor with the additional combinator C,
   using pure_spmf and ap_spmf.  The proof obligations are discharged by the laws
   ap_spmf_id, ap_spmf_comp (with function composition unfolded), ap_spmf_homo,
   ap_spmf_interchange and ap_spmf_C. *)
applicative spmf (C)
for
  pure: pure_spmf
  ap: ap_spmf
by(rule ap_spmf_id ap_spmf_comp[unfolded o_def[abs_def]] ap_spmf_homo ap_spmf_interchange ap_spmf_C)+

lemma set_ap_spmf [simp]: "set_spmf (p  q) = set_spmf p  set_spmf q"
by(auto simp add: ap_spmf_def ap_set_def)

lemma bind_ap_spmf: "bind_spmf (p  x) f = bind_spmf p (λp. x  (λx. f (p x)))"
by(simp add: ap_spmf_conv_bind)

lemma bind_pmf_ap_return_spmf [simp]: "bind_pmf (ap_spmf (return_spmf f) p) g = bind_pmf p (g  map_option f)"
by(auto simp add: ap_spmf_conv_bind bind_spmf_def bind_return_pmf bind_assoc_pmf intro: bind_pmf_cong split: option.split)

(* Rewrites map_spmf f p as an applicative application of return_spmf f to p.
   Declared [applicative_unfold]. *)
lemma map_spmf_conv_ap [applicative_unfold]: "map_spmf f p = return_spmf f  p"
by(simp add: map_spmf_conv_bind_spmf ap_spmf_conv_bind)

end

end

Theory List_Bits

(* Title: List_Bits.thy
  Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹Exclusive or on lists›

theory List_Bits imports Misc_CryptHOL begin

(* Exclusive or on any type providing uminus, inf and sup: the meet (inf) of the join
   (sup x y) with the complement of the meet (inf x y).  Introduced as a
   right-associative infix operator of priority 67. *)
definition xor :: "'a  'a  'a :: {uminus,inf,sup}" (infixr "" 67)
where "x  y = inf (sup x y) (- (inf x y))"

lemma xor_bool_def [iff]: fixes x y :: bool shows "x  y  x  y"
by(auto simp add: xor_def)

lemma xor_commute:
  fixes x y :: "'a :: {semilattice_sup,semilattice_inf,uminus}"
  shows "x  y = y  x"
by(simp add: xor_def sup.commute inf.commute)

(* Associativity of xor in a boolean algebra, proved by unfolding xor_def and
   normalising with the lattice AC and distributivity rules. *)
lemma xor_assoc:
  fixes x y :: "'a :: boolean_algebra"
  shows "(x  y)  z = x  (y  z)"
by(simp add: xor_def inf_sup_aci inf_sup_distrib1 inf_sup_distrib2)

lemma xor_left_commute:
  fixes x y :: "'a :: boolean_algebra"
  shows "x  (y  z) = y  (x  z)"
by (metis xor_assoc xor_commute)

(* Simp rules for xor with the lattice constants in a boolean algebra:
   bot is a neutral element on either side (xor_bot, bot_xor) and xor with top on
   either side yields the complement - x (xor_top, top_xor). *)
lemma [simp]:
  fixes x :: "'a :: boolean_algebra"
  shows xor_bot: "x  bot = x"
  and bot_xor: "bot  x = x"
  and xor_top: "x  top = - x"
  and top_xor: "top  x = - x"
by(simp_all add: xor_def)

lemma xor_inverse [simp]:
  fixes x :: "'a :: boolean_algebra"
  shows "x  x = bot"
by(simp add: xor_def)

lemma xor_left_inverse [simp]:
  fixes x :: "'a :: boolean_algebra"
  shows "x  x  y = y"
by(metis xor_left_commute xor_inverse xor_bot)

lemmas xor_ac = xor_assoc xor_commute xor_left_commute


(* Element-wise xor of two lists: zip them and apply xor to every pair.  Since it is
   built from zip, the result is as long as the shorter argument (see length_xor_list). *)
definition xor_list :: "'a :: {uminus,inf,sup} list  'a list  'a list"  (infixr "[⊕]" 67)
where "xor_list xs ys = map (case_prod (⊕)) (zip xs ys)"

lemma xor_list_unfold:
  "xs [⊕] ys = (case xs of []  [] | x # xs'  (case ys of []  [] | y # ys'  x  y # xs' [⊕] ys'))"
by(simp add: xor_list_def split: list.split)

lemma xor_list_commute: fixes xs ys :: "'a :: {semilattice_sup,semilattice_inf,uminus} list"
  shows "xs [⊕] ys = ys [⊕] xs"
unfolding xor_list_def by(subst zip_commute)(auto simp add: split_def xor_commute)

lemma xor_list_assoc [simp]: 
  fixes xs ys :: "'a :: boolean_algebra list"
  shows "(xs [⊕] ys) [⊕] zs = xs [⊕] (ys [⊕] zs)"
unfolding xor_list_def zip_map1 zip_map2
apply(subst (2) zip_commute)
apply(subst zip_left_commute)
apply(subst (2) zip_commute)
apply(auto simp add: zip_map2 split_def xor_assoc)
done

lemma xor_list_left_commute:
  fixes xs ys zs :: "'a :: boolean_algebra list"
  shows "xs [⊕] (ys [⊕] zs) = ys [⊕] (xs [⊕] zs)"
by(metis xor_list_assoc xor_list_commute)

lemmas xor_list_ac = xor_list_assoc xor_list_commute xor_list_left_commute

lemma xor_list_inverse [simp]: 
  fixes xs :: "'a :: boolean_algebra list"
  shows "xs [⊕] xs = replicate (length xs) bot"
by(simp add: xor_list_def zip_same_conv_map o_def map_replicate_const)

lemma xor_replicate_bot_right [simp]:
  fixes xs :: "'a :: boolean_algebra list"
  shows " length xs  n; x = bot   xs [⊕] replicate n x = xs"
by(simp add: xor_list_def zip_replicate2 o_def)

lemma xor_replicate_bot_left [simp]:
  fixes xs :: "'a :: boolean_algebra list"
  shows " length xs  n; x = bot   replicate n x [⊕] xs = xs"
by(simp add: xor_list_commute)

(* Simp rule: xor-ing xs twice onto ys gives back ys, under a side condition relating
   length ys and length xs.  Proved by re-associating with xor_list_assoc and then
   simplifying. *)
lemma xor_list_left_inverse [simp]:
  fixes xs :: "'a :: boolean_algebra list"
  shows "length ys  length xs  xs [⊕] (xs [⊕] ys) = ys"
by(subst xor_list_assoc[symmetric])(simp)

lemma length_xor_list [simp]: "length (xor_list xs ys) = min (length xs) (length ys)"
by(simp add: xor_list_def)

(* xor_list xs is injective on nlists UNIV n, under a side condition relating n and
   length xs.  Injectivity is derived from xor_list_left_inverse. *)
lemma inj_on_xor_list_nlists [simp]:
  fixes xs :: "'a :: boolean_algebra list"
  shows "n  length xs  inj_on (xor_list xs) (nlists UNIV n)"
apply(clarsimp simp add: inj_on_def in_nlists_UNIV)
using xor_list_left_inverse by fastforce

(* One-time pad: mapping xor_list xs over spmf_of_set (nlists UNIV n) yields
   spmf_of_set (nlists UNIV n) again, under a side condition relating length xs and n.
   The proof reduces the claim to an equality of the underlying sets
   (arg_cong on spmf_of_set) and uses xor_list_left_inverse to exhibit preimages. *)
lemma one_time_pad:
  fixes xs :: "_ :: boolean_algebra list"
  shows "length xs  n  map_spmf (xor_list xs) (spmf_of_set (nlists UNIV n)) = spmf_of_set (nlists UNIV n)"
by(auto 4 3 simp add: in_nlists_UNIV intro: xor_list_left_inverse[symmetric] rev_image_eqI intro!: arg_cong[where f=spmf_of_set])

end

Theory Environment_Functor

theory Environment_Functor imports
  Applicative_Lifting.Applicative_Environment
begin

subsection ‹The environment functor›

(* The environment functor: an ('i, 'a) envir is a function from indices 'i to
   values 'a. *)
type_synonym ('i, 'a) envir = "'i  'a"

lemma const_apply [simp]: "const x i = x"
by(simp add: const_def)

context includes applicative_syntax begin

lemma ap_envir_apply [simp]: "(f  x) i = f i (x i)"
by(simp add: apf_def)

(* Predicate on a boolean-valued environment p, defined in terms of p x for the
   indices x (introduction rule all_envirI, destruction rule all_envirD). *)
definition all_envir :: "('i, bool) envir  bool"
where "all_envir p  (x. p x)"

lemma all_envirI [Pure.intro!, intro!]: "(x. p x)  all_envir p"
by(simp add: all_envir_def)

lemma all_envirE [Pure.elim 2, elim]: "all_envir p  (p x  thesis)  thesis"
by(simp add: all_envir_def)

lemma all_envirD: "all_envir p  p x"
by(simp add: all_envir_def)


(* Lifts a predicate p on values to environments: the constant environment const p is
   applied applicatively to f and the result is passed to all_envir. *)
definition pred_envir :: "('a  bool)  ('i, 'a) envir  bool"
where "pred_envir p f = all_envir (const p  f)"

lemma pred_envir_conv: "pred_envir p f  (x. p (f x))"
by(auto simp add: pred_envir_def)

lemma pred_envirI [Pure.intro!, intro!]: "(x. p (f x))  pred_envir p f"
by(auto simp add: pred_envir_def)

lemma pred_envirD: "pred_envir p f  p (f x)"
by(auto simp add: pred_envir_def)

lemma pred_envirE [Pure.elim 2, elim]: "pred_envir p f  (p (f x)  thesis)  thesis"
by(simp add: pred_envir_conv)

lemma pred_envir_mono: " pred_envir p f; x. p (f x)  q (g x)   pred_envir q g"
by blast

(* Lifts a binary relation p on values to environments: const p is applied
   applicatively to f and g and the result is passed to all_envir.
   rel_envir_conv_rel_fun identifies it with rel_fun (=). *)
definition rel_envir :: "('a  'b  bool)  ('i, 'a) envir  ('i, 'b) envir  bool"
where "rel_envir p f g  all_envir (const p  f  g)"

lemma rel_envir_conv: "rel_envir p f g  (x. p (f x) (g x))"
by(auto simp add: rel_envir_def)

lemma rel_envir_conv_rel_fun: "rel_envir = rel_fun (=)"
by(simp add: rel_envir_conv rel_fun_def fun_eq_iff)

lemma rel_envirI [Pure.intro!, intro!]: "(x. p (f x) (g x))  rel_envir p f g"
by(auto simp add: rel_envir_def)

lemma rel_envirD: "rel_envir p f g  p (f x) (g x)"
by(auto simp add: rel_envir_def)

lemma rel_envirE [Pure.elim 2, elim]: "rel_envir p f g  (p (f x) (g x)  thesis)  thesis"
by(simp add: rel_envir_conv)

lemma rel_envir_mono: " rel_envir p f g; x. p (f x) (g x)  q (f' x) (g' x)   rel_envir q f' g'"
by blast

lemma rel_envir_mono1: " pred_envir p f; x. p (f x)  q (f' x) (g' x)   rel_envir q f' g'"
by blast

lemma pred_envir_mono2: " rel_envir p f g; x. p (f x) (g x)  q (f' x)   pred_envir q f'"
by blast

end

end

Theory Partial_Function_Set

(* Title: Partial_Function_Set.thy
  Author: Andreas Lochbihler, ETH Zurich *)

theory Partial_Function_Set imports Main begin

subsection ‹Setup for \<^theory_text>‹partial_function› for sets›

(* In any complete lattice, the order together with Sup satisfies the locale
   partial_function_definitions.  Proved from Sup_upper and Sup_least. *)
lemma (in complete_lattice) lattice_partial_function_definition:
  "partial_function_definitions (≤) Sup"
by(unfold_locales)(auto intro: Sup_upper Sup_least)

(* Instantiates partial_function_definitions for sets ordered by inclusion, with
   Union as the least upper bound; the resulting facts get the prefix "set.". *)
interpretation set: partial_function_definitions "(⊆)" Union
by(rule lattice_partial_function_definition)

lemma fun_lub_Sup: "fun_lub Sup = (Sup :: _  _ :: complete_lattice)"
by(fastforce simp add: fun_lub_def fun_eq_iff Sup_fun_def intro: Sup_eqI SUP_upper SUP_least)

lemma set_admissible: "set.admissible (λf :: 'a  'b set. x y. y  f x  P x y)"
by(rule ccpo.admissibleI)(auto simp add: fun_lub_Sup)

(* Monotonicity of set-valued functionals: from the pointwise extension (fun_ord) of
   set inclusion to set inclusion. *)
abbreviation "mono_set  monotone (fun_ord (⊆)) (⊆)"

(* Scott-style fixpoint induction for set-valued functions defined as a fixpoint.
   U and C convert between the curried form 'c and the uncurried form 'b => 'a set
   (inverse2 states that U undoes C).  The step premise carries an induction
   hypothesis for the elements of U f and must establish P for the elements of
   U (F f).  The premise "x = x" is a dummy whose only purpose, per its name, is to
   fix the order of variables.  The result is obtained from set.fixp_induct_uc with
   the admissibility fact set_admissible. *)
lemma fixp_induct_set_scott:
  fixes F :: "'c  'c"
  and U :: "'c  'b  'a set"
  and C :: "('b  'a set)  'c"
  and P :: "'b  'a  bool"
  and x and y
  assumes mono: "x. mono_set (λf. U (F (C f)) x)"
  and eq: "f  C (ccpo.fixp (fun_lub Sup) (fun_ord (≤)) (λf. U (F (C f))))"
  and inverse2: "f. U (C f) = f"
  and step: "f x y.  x y. y  U f x  P x y; y  U (F f) x   P x y"
  and enforce_variable_ordering: "x = x"
  and elem: "y  U f x"
  shows "P x y"
using step elem set.fixp_induct_uc[of U F C, OF mono eq inverse2 set_admissible, of P]
by blast


(* On a complete lattice, the locale-level fixpoint ccpo.fixp instantiated with Sup
   and the lattice order coincides with the type-class constant ccpo_class.fixp.
   The order is introduced via "defines" as le; the proof first shows that Sup and le
   form a ccpo and then unfolds the definitions of fixp and iterates on both sides. *)
lemma fixp_Sup_le:
  defines "le  ((≤) :: _ :: complete_lattice  _)"
  shows "ccpo.fixp Sup le = ccpo_class.fixp"
proof -
  have "class.ccpo Sup le (<)" unfolding le_def by unfold_locales
  thus ?thesis
    by(simp add: ccpo.fixp_def fixp_def ccpo.iterates_def iterates_def ccpo.iteratesp_def iteratesp_def fun_eq_iff le_def)
qed

lemma fun_ord_le: "fun_ord (≤) = (≤)"
by(auto simp add: fun_ord_def fun_eq_iff le_fun_def)

lemma monotone_le_le: "monotone (≤) (≤) = mono"
by(simp add: monotone_def[abs_def] mono_def[abs_def])

(* Induction rule for set-valued partial functions; it is the rule handed to
   Partial_Function.init for mode "set" in the declaration that follows.
   Compared with fixp_induct_set_scott, the step premise may assume membership in the
   functional applied to the intersection (inf) of U f with the set of elements
   satisfying P.
   Proof outline: the ccpo fixpoint is rewritten to lfp (eq'), the step premise is
   restated for the lfp (step'), and lfp_induct shows that the lfp is below
   λx. {y. P x y}. *)
lemma fixp_induct_set:
  fixes F :: "'c  'c"
  and U :: "'c  'b  'a set"
  and C :: "('b  'a set)  'c"
  and P :: "'b  'a  bool"
  and x and y
  assumes mono: "x. mono_set (λf. U (F (C f)) x)"
  and eq: "f  C (ccpo.fixp (fun_lub Sup) (fun_ord (≤)) (λf. U (F (C f))))"
  and inverse2: "f. U (C f) = f"

  and step: "f' x y.  x. U f' x = U f' x; y  U (F (C (inf (U f) (λx. {y. P x y})))) x   P x y"
    ― ‹partial\_function requires a quantifier over f', so let's have a fake one›
  and elem: "y  U f x"
  shows "P x y"
proof -
  (* Pointwise monotonicity of the functional gives monotonicity on the function lattice. *)
  from mono
  have mono': "mono (λf. U (F (C f)))"
    by(simp add: fun_ord_le monotone_le_le mono_def le_fun_def)
  (* Replace the ccpo fixpoint in eq by the least fixpoint lfp. *)
  hence eq': "f  C (lfp (λf. U (F (C f))))"
    using eq unfolding fun_ord_le fun_lub_Sup fixp_Sup_le by(simp add: lfp_eq_fixp)

  let ?f = "C (lfp (λf. U (F (C f))))"
  (* The step premise, instantiated with refl for the fake quantified premise. *)
  have step': "x y.  y  U (F (C (inf (U ?f) (λx. {y. P x y})))) x   P x y"
    unfolding eq'[symmetric] by(rule step[OF refl])

  let ?P = "λx. {y. P x y}"
  from mono' have "lfp (λf. U (F (C f)))  ?P"
    by(rule lfp_induct)(auto intro!: le_funI step' simp add: inverse2)
  with elem show ?thesis
    by(subst (asm) eq')(auto simp add: inverse2 le_fun_def)
qed

(* Registers the mode "set" with the partial_function package, passing the fixpoint
   combinator, the monotonicity predicate, the unfolding rule, the Scott-style
   induction rule and fixp_induct_set as the optional induction rule. *)
declaration Partial_Function.init "set" @{term set.fixp_fun}
  @{term set.mono_body} @{thm set.fixp_rule_uc} @{thm set.fixp_induct_uc}
  (SOME @{thm fixp_induct_set})

(* Monotonicity rules for mode "set", declared [partial_function_mono] for the
   monotonicity prover of partial_function: insert, indexed union, Set.bind,
   binary union and intersection, set difference with a fixed right operand,
   image and inverse image all preserve mono_set.
   Set.bind is unfolded to a union via bind_UNION before the rules are proved. *)
lemma [partial_function_mono]:
  shows insert_mono: "mono_set A  mono_set (λf. insert x (A f))"
  and UNION_mono: "mono_set B; y. mono_set (λf. C y f)  mono_set (λf. yB f. C y f)"
  and set_bind_mono: "mono_set B; y. mono_set (λf. C y f)  mono_set (λf. Set.bind (B f) (λy. C y f))"
  and Un_mono: " mono_set A; mono_set B   mono_set (λf. A f  B f)"
  and Int_mono: " mono_set A; mono_set B   mono_set (λf. A f  B f)"
  and Diff_mono1: "mono_set A  mono_set (λf. A f - X)"
  and image_mono: "mono_set A  mono_set (λf. g ` A f)"
  and vimage_mono: "mono_set A  mono_set (λf. g -` A f)"
unfolding bind_UNION by(fast intro!: monotoneI dest: monotoneD)+

(* Sanity check for mode "set": a definition whose right-hand side combines insert,
   union, intersection, difference, image and inverse image around recursive calls,
   so that the monotonicity rules above are exercised.
   NOTE(review): unlike test2 further down, this constant is not wrapped in a
   private context, so the name "test" is exported from this theory - confirm that
   this is intended. *)
partial_function (set) test :: "'a list  nat  bool  int set"
where
  "test xs i j = insert 4 (test [] 0 j  test [] 1 True  test [] 2 False - {5}  uminus ` test [undefined] 0 True  uminus -` test [] 1 False)"

(* Dual instantiation: sets ordered by reverse inclusion with Inter as the lub,
   obtained from the complete-lattice lemma applied to the dual lattice.
   The resulting facts get the prefix "coset.". *)
interpretation coset: partial_function_definitions "(⊇)" Inter
by(rule complete_lattice.lattice_partial_function_definition[OF dual_complete_lattice])

lemma fun_lub_Inf: "fun_lub Inf = (Inf :: _  _ :: complete_lattice)"
by(auto simp add: fun_lub_def fun_eq_iff Inf_fun_def intro: Inf_eqI INF_lower INF_greatest)

lemma fun_ord_ge: "fun_ord (≥) = (≥)"
by(auto simp add: fun_ord_def fun_eq_iff le_fun_def)

lemma coset_admissible: "coset.admissible (λf :: 'a  'b set. x y. P x y  y  f x)"
by(rule ccpo.admissibleI)(auto simp add: fun_lub_Inf)

(* Monotonicity with respect to reverse set inclusion, used for mode "coset". *)
abbreviation "mono_coset  monotone (fun_ord (⊇)) (⊇)"

(* For f monotone with respect to the reversed order, the greatest fixpoint gfp f
   equals the ccpo fixpoint computed with Inf as lub and the reversed order.
   Proved by antisymmetry; the ccpo locale is interpreted for the dual lattice so
   that fixp_unfold and fixp_lowerbound are available. *)
lemma gfp_eq_fixp:
  fixes f :: "'a :: complete_lattice  'a"
  assumes f: "monotone (≥) (≥) f"
  shows "gfp f = ccpo.fixp Inf (≥) f"
proof (rule antisym)
  (* Monotonicity in the reversed order is ordinary monotonicity. *)
  from f have f': "mono f" by(simp add: mono_def monotone_def)

  interpret ccpo Inf "(≥)" "mk_less (≥) :: 'a  _"
    by(rule ccpo)(rule complete_lattice.lattice_partial_function_definition[OF dual_complete_lattice])
  show "ccpo.fixp Inf (≥) f  gfp f"
    by(rule gfp_upperbound)(subst fixp_unfold[OF f], rule order_refl)

  show "gfp f  ccpo.fixp Inf (≥) f"
    by(rule fixp_lowerbound[OF f])(subst gfp_unfold[OF f'], rule order_refl)
qed

(* Coinduction rule for mode "coset"; it is the rule handed to Partial_Function.init
   in the declaration that follows.  It is the dual of fixp_induct_set: the fixpoint
   is taken with Inter and the reversed order, and the step premise is phrased with
   the negation of P and the join (sup) of {y. ¬ P x y} with U f.
   Proof outline: argue by contraposition (contrapos_np) from elem, identify U f
   with a gfp via gfp_eq_fixp (eq'), and apply the coinduct rule for gfp to the set
   of elements violating P. *)
lemma fixp_coinduct_set:
  fixes F :: "'c  'c"
  and U :: "'c  'b  'a set"
  and C :: "('b  'a set)  'c"
  and P :: "'b  'a  bool"
  and x and y
  assumes mono: "x. mono_coset (λf. U (F (C f)) x)"
  and eq: "f  C (ccpo.fixp (fun_lub Inter) (fun_ord (≥)) (λf. U (F (C f))))"
  and inverse2: "f. U (C f) = f"

  and step: "f' x y.  x. U f' x = U f' x; ¬ P x y   y  U (F (C (sup (λx. {y. ¬ P x y}) (U f)))) x"
    ― ‹partial\_function requires a quantifier over f', so let's have a fake one›
  and elem: "y  U f x"
  shows "P x y"
using elem
proof(rule contrapos_np)
  (* Two formulations of monotonicity of the functional, both derived from mono. *)
  have mono': "monotone (≥) (≥) (λf. U (F (C f)))"
    and mono'': "mono (λf. U (F (C f)))"
    using mono by(simp_all add: monotone_def fun_ord_def le_fun_def mono_def)
  (* U f is the greatest fixpoint of the functional. *)
  hence eq': "U f = gfp (λf. U (F (C f)))"
    by(subst eq)(simp add: fun_lub_Inf fun_ord_ge gfp_eq_fixp inverse2)

  let ?P = "λx. {y. ¬ P x y}"
  have "?P  gfp (λf. U (F (C f)))"
    using mono'' by(rule coinduct)(auto intro!:  le_funI dest: step[OF refl] simp add: eq')
  moreover
  assume "¬ P x y"
  ultimately show "y  U f x" by(auto simp add: le_fun_def eq')
qed

(* Registers the mode "coset" with the partial_function package, with
   fixp_coinduct_set as the optional (co)induction rule. *)
declaration Partial_Function.init "coset" @{term coset.fixp_fun}
  @{term coset.mono_body} @{thm coset.fixp_rule_uc} @{thm coset.fixp_induct_uc}
  (SOME @{thm fixp_coinduct_set})

(* Monotonicity with respect to reverse set inclusion, used to state the
   monotonicity rules for mode "coset" below.
   NOTE(review): the right-hand side is textually the same as that of mono_coset
   above - confirm whether two abbreviations for the same term are intended. *)
abbreviation "mono_set'  monotone (fun_ord (⊇)) (⊇)"

lemma [partial_function_mono]:
  shows insert_mono': "mono_set' A  mono_set' (λf. insert x (A f))"
  and UNION_mono': "mono_set' B; y. mono_set' (λf. C y f)  mono_set' (λf. yB f. C y f)"
  and set_bind_mono': "mono_set' B; y. mono_set' (λf. C y f)  mono_set' (λf. Set.bind (B f) (λy. C y f))"
  and Un_mono': " mono_set' A; mono_set' B   mono_set' (λf. A f  B f)"
  and Int_mono': " mono_set' A; mono_set' B   mono_set' (λf. A f  B f)"
unfolding bind_UNION by(fast intro!: monotoneI dest: monotoneD)+

(* Private sanity check for mode "coset": a recursive definition test2 and a
   coinduction principle for it derived from the generated rule test2.raw_induct.
   Both are private, so nothing is exported from this context. *)
context begin
private partial_function (coset) test2 :: "nat  nat set"
where "test2 x = insert x (test2 (Suc x))"

(* Membership in test2 x is shown by contraposition (contrapos_pp) followed by the
   raw induction rule, with premise * discharging the step. *)
private lemma test2_coinduct:
  assumes "P x y"
  and *: "x y. P x y  y = x  (P (Suc x) y  y  test2 (Suc x))"
  shows "y  test2 x"
using P x y
apply(rule contrapos_pp)
apply(erule test2.raw_induct[rotated])
apply(simp add: *)
done

end

end

Theory Negligible

(* Title: Negligible.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Negligibility›

theory Negligible imports
  Complex_Main
  Landau_Symbols.Landau_More
begin

named_theorems negligible_intros

(* A function f from nat to real is negligible if, for the positive exponents c,
   f is in little-o of λx. inverse (x powr c), i.e. f is dominated by the inverse
   of x to the power c. *)
definition negligible :: "(nat  real)  bool" (* TODO: generalise types? *)
where "negligible f  (c>0. f  o(λx. inverse (x powr c)))"

lemma negligibleI [intro?]:
  "(c. c > 0  f  o(λx. inverse (x powr c)))  negligible f"
unfolding negligible_def by(simp)

lemma negligibleD:
  " negligible f; c > 0   f  o(λx. inverse (x powr c))"
unfolding negligible_def by(simp)

(* Variant of negligibleD without the positivity assumption on the exponent c.
   The proof applies negligibleD at the exponent max 1 c and then weakens the
   bound: x powr c is in big-O of x powr max 1 c, so the inverses are related the
   other way round (landau_o.big.inverse). *)
lemma negligibleD_real:
  assumes "negligible f"
  shows "f  o(λx. inverse (x powr c))"
proof -
  let ?c = "max 1 c"
  have "f  o(λx. inverse (x powr ?c))" using assms by(rule negligibleD) simp
  also have "(λx. x powr c)  O(λx. real x powr max 1 c)"
    by(rule bigoI[where c=1])(auto simp add: eventually_at_top_linorder intro!: exI[where x=1] powr_mono)
  then have "(λx. inverse (real x powr max 1 c))  O(λx. inverse (x powr c))"
    by(auto simp add: eventually_at_top_linorder exI[where x=1] intro: landau_o.big.inverse)
  finally show ?thesis .
qed

(* Negligibility is inherited downwards along big-O: if g is negligible and f is in
   O(g), then f is negligible.  Proved with landau_o.big_small_trans. *)
lemma negligible_mono: " negligible g; f  O(g)   negligible f"
by(rule negligibleI)(drule (1) negligibleD; erule (1) landau_o.big_small_trans)

lemma negligible_le: " negligible g; η. ¦f η¦  g η   negligible f"
by(erule negligible_mono)(force intro: order_trans intro!: eventually_sequentiallyI landau_o.big_mono)

lemma negligible_K0 [negligible_intros, simp, intro!]: "negligible (λ_. 0)"
by(rule negligibleI) simp

lemma negligible_0 [negligible_intros, simp, intro!]: "negligible 0"
by(simp add: zero_fun_def)

lemma negligible_const_iff [simp]: "negligible (λ_. c :: real)  c = 0"
by(auto simp add: negligible_def const_smallo_inverse_powr filterlim_real_sequentially dest!: spec[where x=1])

lemma not_negligible_1: "¬ negligible (λ_. 1 :: real)"
by simp

(* The pointwise sum of two negligible functions is negligible (sum_in_smallo).
   Collected in negligible_intros. *)
lemma negligible_plus [negligible_intros]:
  " negligible f; negligible g   negligible (λη. f η + g η)"
by(auto intro!: negligibleI dest!: negligibleD intro: sum_in_smallo)

lemma negligible_uminus [simp]: "negligible (λη. - f η)  negligible f"
by(simp add: negligible_def)

lemma negligible_uminusI [negligible_intros]: "negligible f  negligible (λη. - f η)"
by simp

lemma negligible_minus [negligible_intros]:
  " negligible f; negligible g   negligible (λη. f η - g η)"
by(auto simp add: uminus_add_conv_diff[symmetric] negligible_plus simp del: uminus_add_conv_diff)

lemma negligible_cmult: "negligible (λη. c * f η)  negligible f  c = 0"
by(auto intro!: negligibleI dest!: negligibleD)

lemma negligible_cmultI [negligible_intros]:
  "(c  0  negligible f)  negligible (λη. c * f η)"
by(auto simp add: negligible_cmult)

lemma negligible_multc: "negligible (λη. f η * c)  negligible f  c = 0"
by(subst mult.commute)(simp add: negligible_cmult)

lemma negligible_multcI [negligible_intros]:
  "(c  0  negligible f)  negligible (λη. f η * c)"
by(auto simp add: negligible_multc)

(* The pointwise product of two negligible functions is negligible.
   For a target exponent c, both factors are bounded with exponent c / 2; the
   product of the two bounds is then rewritten to the bound with exponent c. *)
lemma negligible_times [negligible_intros]:
  assumes f: "negligible f"
  and g: "negligible g"
  shows "negligible (λη. f η * g η :: real)"
proof
  fix c :: real
  assume "0 < c"
  hence "0 < c / 2" by simp
  from negligibleD[OF f this] negligibleD[OF g this]
  have "(λη. f η * g η)  o(λx. inverse (x powr (c / 2)) * inverse (x powr (c / 2)))"
    by(rule landau_o.small_mult)
  also have " = o(λx. inverse (x powr c))"
    by(rule landau_o.small.cong)(auto simp add: inverse_mult_distrib[symmetric] powr_add[symmetric] eventually_at_top_linorder intro!: exI[where x=1] simp del: inverse_mult_distrib)
  finally show "(λη. f η * g η)  " .
qed

(* A positive natural power of a negligible function is negligible.
   By induction on n; the successor case splits on n and uses negligible_times. *)
lemma negligible_power [negligible_intros]:
  assumes "negligible f"
  and "n > 0"
  shows "negligible (λη. f η ^ n :: real)"
using n > 0
proof(induct n)
  case (Suc n)
  thus ?case using ‹negligible f by(cases n)(simp_all add: negligible_times)
qed simp

lemma negligible_powr [negligible_intros]:
  assumes f: "negligible f"
  and p: "p > 0"         
  shows "negligible (λx. ¦f x¦ powr p :: real)"
proof
  fix c :: real
  let ?c = "c / p"
  assume c: "0 < c"
  with p have "0 < ?c" by simp
  with f have "f  o(λx. inverse (x powr ?c))" by(rule negligibleD)
  hence "(λx. ¦f x¦ powr p)  o(λx. ¦inverse (x powr ?c)¦ powr p)" using p by(rule smallo_powr)
  also have " = o(λx. inverse (x powr c))"
    apply(rule landau_o.small.cong) using p by(auto simp add: powr_powr)
  finally show "(λx. ¦f x¦ powr p)  " .
qed

lemma negligible_abs [simp]: "negligible (λx. ¦f x¦)  negligible f"
by(simp add: negligible_def)

lemma negligible_absI [negligible_intros]: "negligible f  negligible (λx. ¦f x¦)"
by(simp)

(* The exponential λx. k powr x is negligible for a base k with k < 1 (the other
   assumption bounds k from below by 0).  The case k = 0 is handled by
   simplification.  Otherwise, powr_fast_growth_tendsto shows that every power
   x powr c is little-o of (inverse k) powr x; taking inverses on both sides and
   rewriting with powr_minus gives the claim. *)
lemma negligible_powrI [negligible_intros]:
  assumes "0  k" "k < 1"
  shows "negligible (λx. k powr x)"
proof(cases "k = 0")
  case True
  thus ?thesis by simp
next
  case False
  show ?thesis
  proof
    fix c :: real
    assume "0 < c"
    then have "(λx. real x powr c)  o(λx. inverse k powr real x)" using assms False
      by(intro powr_fast_growth_tendsto)(simp_all add: one_less_inverse_iff filterlim_real_sequentially)
    then have "(λx. inverse (k powr - real x))  o(λx. inverse (real x powr c))" using assms
      by(intro landau_o.small.inverse)(auto simp add: False eventually_sequentially powr_minus intro: exI[where x=1])
    also have "(λx. inverse (k powr - real x)) = (λx. k powr real x)" by(simp add: powr_minus)
    finally show "  o(λx. inverse (x powr c))" .
  qed
qed

(* The geometric sequence λn. k ^ n is negligible when the absolute value of k is
   below 1.  For k = 0 the sequence is compared against the zero function via
   negligible_mono.  Otherwise negligible_powrI is applied to the absolute value of
   k, powr is converted to a natural power (powr_realpow), and the absolute value is
   moved outside the power (power_abs). *)
lemma negligible_powerI [negligible_intros]:
  fixes k :: real
  assumes "¦k¦ < 1"
  shows "negligible (λn. k ^ n)"
proof(cases "k = 0")
  case True
  show ?thesis using negligible_K0
    by(rule negligible_mono)(auto intro: exI[where x=1] simp add: True eventually_at_top_linorder)
next
  case False
  hence "0 < ¦k¦" by auto
  from assms have "negligible (λx. ¦k¦ powr real x)" using negligible_powrI[of "¦k¦"] by simp
  hence "negligible (λx. ¦k¦ ^ x)" using False
    by(elim negligible_mono)(simp add: powr_realpow)
  then show ?thesis by(simp add: power_abs[symmetric])
qed

lemma negligible_inverse_powerI [negligible_intros]: "¦k¦ > 1  negligible (λη. 1 / k ^ η)"
using negligible_powerI[of "1 / k"] by(simp add: power_one_over)

(* polynomial f holds when f is in big-O of λx. x powr n for some exponent n
   (n is free in the single introduction rule).  Stated as an inductive predicate
   with one rule, so that polynomial.simps provides the unfolding. *)
inductive polynomial :: "(nat  real)  bool"
  for f
where "f  O(λx. x powr n)  polynomial f"

(* A negligible function times a polynomially bounded one is negligible.
   For a target exponent c, f is bounded with exponent c + n (negligibleD_real, so
   no sign condition on c + n is needed) and g with x powr n; the product of the two
   bounds simplifies to the bound with exponent c. *)
lemma negligible_times_poly:
  assumes f: "negligible f"
  and g: "g  O(λx. x powr n)"
  shows "negligible (λx. f x * g x)"
proof
  fix c :: real
  assume c: "0 < c"
  from negligibleD_real[OF f] g
  have "(λx. f x * g x)  o(λx. inverse (x powr (c + n)) * x powr n)"
    by(rule landau_o.small_big_mult)
  also have " = o(λx. inverse (x powr c))"
    by(rule landau_o.small.cong)(auto simp add: powr_minus[symmetric] powr_add[symmetric] intro!: exI[where x=0])
  finally show "(λx. f x * g x)  o(λx. inverse (x powr c))" .
qed

lemma negligible_poly_times:
  " f  O(λx. x powr n); negligible g   negligible (λx. f x * g x)"
by(subst mult.commute)(rule negligible_times_poly)

lemma negligible_times_polynomial [negligible_intros]:
  " negligible f; polynomial g   negligible (λx. f x * g x)"
by(clarsimp simp add: polynomial.simps negligible_times_poly)

lemma negligible_polynomial_times [negligible_intros]:
  " polynomial f; negligible g   negligible (λx. f x * g x)"
by(clarsimp simp add: polynomial.simps negligible_poly_times)

lemma negligible_divide_poly1:
  " f  O(λx. x powr n); negligible (λη. 1 / g η)   negligible (λη. real (f η) / g η)"
by(drule (1) negligible_times_poly) simp

lemma negligible_divide_polynomial1 [negligible_intros]:
  " polynomial f; negligible (λη. 1 / g η)   negligible (λη. real (f η) / g η)"
by(clarsimp simp add: polynomial.simps negligible_divide_poly1)

end

Theory Resumption

(* Title: Resumption.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹The resumption-error monad›

theory Resumption
imports
  Misc_CryptHOL
  Partial_Function_Set
begin

(* Codatatype of resumptions with two constructors:
     Done r       - carries an optional result r :: 'a option (selector result);
     Pause out c  - carries an output out (selector output) and a continuation c
                    from inputs to resumptions (selector resume).
   The set functions are named results (for 'a) and outputs (for 'out).
   The where-clause fixes the value of the selector resume on Done: the constant
   function returning Done None. *)
codatatype (results: 'a, outputs: 'out, 'in) resumption
  = Done (result: "'a option")
  | Pause ("output": 'out) (resume: "'in  ('a, 'out, 'in) resumption")
where
  "resume (Done a) = (λinp. Done None)"

code_datatype Done Pause

primcorec bind_resumption :: 
  "('a, 'out, 'in) resumption
      ('a  ('b, 'out, 'in) resumption)  ('b, 'out, 'in) resumption"
where
  " is_Done x; result x  None  is_Done (f (the (result x)))   is_Done (bind_resumption x f)"
| "result (bind_resumption x f) = result x  result  f"
| "output (bind_resumption x f) = (if is_Done x then output (f (the (result x))) else output x)"
| "resume (bind_resumption x f) = (λinp. if is_Done x then resume (f (the (result x))) inp else bind_resumption (resume x inp) f)"

declare bind_resumption.sel [simp del]

adhoc_overloading Monad_Syntax.bind bind_resumption

lemma is_Done_bind_resumption [simp]:
  "is_Done (x  f)  is_Done x  (result x  None  is_Done (f (the (result x))))"
by(simp add: bind_resumption_def)

lemma result_bind_resumption [simp]:
  "is_Done (x  f)  result (x  f) = result x  result  f"
by(simp add: bind_resumption_def)

lemma output_bind_resumption [simp]:
  "¬ is_Done (x  f)  output (x  f) = (if is_Done x then output (f (the (result x))) else output x)"
by(simp add: bind_resumption_def)

lemma resume_bind_resumption [simp]:
  "¬ is_Done (x  f) 
  resume (x  f) = 
  (if is_Done x then resume (f (the (result x))) 
   else (λinp. resume x inp  f))"
by(auto simp add: bind_resumption_def)

(* DONE a is Done (Some a): a resumption whose result is the value a. *)
definition DONE :: "'a  ('a, 'out, 'in) resumption"
where "DONE = Done  Some"

(* ABORT is Done None: a resumption that carries no result. *)
definition ABORT :: "('a, 'out, 'in) resumption"
where "ABORT = Done None"

(* Basic simp rules for DONE and ABORT: discriminator and selector values,
   injectivity of DONE, and (in)equalities between DONE, ABORT, Done and Pause.
   All follow by unfolding DONE_def and ABORT_def. *)
lemma [simp]:
  shows is_Done_DONE: "is_Done (DONE a)"
  and is_Done_ABORT: "is_Done ABORT"
  and result_DONE: "result (DONE a) = Some a"
  and result_ABORT: "result ABORT = None"
  and DONE_inject: "DONE a = DONE b  a = b"
  and DONE_neq_ABORT: "DONE a  ABORT"
  and ABORT_neq_DONE: "ABORT  DONE a"
  and ABORT_eq_Done: "a. ABORT = Done a  a = None"
  and Done_eq_ABORT: "a. Done a = ABORT  a = None"
  and DONE_eq_Done: "b. DONE a = Done b  b = Some a"
  and Done_eq_DONE: "b. Done b = DONE a  b = Some a"
  and DONE_neq_Pause: "DONE a  Pause out c"
  and Pause_neq_DONE: "Pause out c  DONE a"
  and ABORT_neq_Pause: "ABORT  Pause out c"
  and Pause_neq_ABORT: "Pause out c  ABORT"
by(auto simp add: DONE_def ABORT_def)

lemma resume_ABORT [simp]:
  "resume (Done r) = (λinp. ABORT)"
by(simp add: ABORT_def)

declare resumption.sel(3)[simp del]

lemma results_DONE [simp]: "results (DONE x) = {x}"
by(simp add: DONE_def)

lemma results_ABORT [simp]: "results ABORT = {}"
by(simp add: ABORT_def)

lemma outputs_ABORT [simp]: "outputs ABORT = {}"
by(simp add: ABORT_def)

lemma outputs_DONE [simp]: "outputs (DONE x) = {}"
by(simp add: DONE_def)

(* Case-analysis rule for a resumption r with is_Done r: either r = DONE x for some
   x (case DONE) or r = ABORT (case ABORT).  Registered as a cases rule for the
   predicate is_Done. *)
lemma is_Done_cases [cases pred]:
  assumes "is_Done r"
  obtains (DONE) x where "r = DONE x" | (ABORT) "r = ABORT"
using assms by(cases r) auto

(* Relates "r is not Done" to r having the form Pause out c; proved by case
   distinction on r. *)
lemma not_is_Done_conv_Pause: "¬ is_Done r  (out c. r = Pause out c)"
by(cases r) auto

(* Code equation for bind on a finished resumption: with no result the bind
   is Done None, with Some a it is f a. *)
lemma Done_bind [code]:
  "Done a  f = (case a of None  Done None | Some a  f a)"
by(rule resumption.expand)(auto split: option.split)

(* Binding f onto DONE a just applies f to a (left identity of bind). *)
lemma DONE_bind [simp]:
  "DONE a  f = f a"
by(simp add: DONE_def Done_bind)

(* Simp rule and code equation for bind on a paused resumption: the output is
   kept and f is bound onto what the continuation returns for each input. *)
lemma bind_resumption_Pause [simp, code]: fixes cont shows
  "Pause out cont  f
  = Pause out (λinp. cont inp  f)"
by(rule resumption.expand)(simp_all)

(* Binding DONE onto x gives back x (right identity of bind); proved by
   coinduction over x. *)
lemma bind_DONE [simp]:
  "x  DONE = x"
by(coinduction arbitrary: x)(auto simp add: split_beta o_def)

(* Associativity of bind on resumptions: binding g onto (r bound with f) is
   the same as binding onto r the function that runs f and then g.
   The type annotation uses the parameter order ('a, 'out, 'in) that the other
   declarations of this theory use for resumption (e.g. DONE, ABORT, ensure),
   so the type-variable names match their roles; this is only a renaming of
   type variables and does not change the statement.
   Proof: strong coinduction over r, then a case split on the result of r. *)
lemma bind_bind_resumption:
  fixes r :: "('a, 'out, 'in) resumption"
  shows "(r  f)  g = do { x  r; f x  g }"
apply(coinduction arbitrary: r rule: resumption.coinduct_strong)
apply(auto simp add: split_beta bind_eq_Some_conv)
apply(case_tac [!] "result r")
apply simp_all
done

(* The three monad laws for resumptions, collected under one name. *)
lemmas resumption_monad = DONE_bind bind_DONE bind_bind_resumption

(* ABORT absorbs any continuation under bind. *)
lemma ABORT_bind [simp]: "ABORT  f = ABORT"
by(simp add: ABORT_def Done_bind)

(* Unfolds a bind whose first argument f is Done: ABORT if f has no result,
   otherwise g applied to that result.  Not a simp rule. *)
lemma bind_resumption_is_Done: "is_Done f  f  g = (if result f = None then ABORT else g (the (result f)))"
by(rule resumption.expand) auto

(* Characterises when a bind equals Done x: either f is DONE y for some y
   with g y = Done x, or f is ABORT and x is None. *)
lemma bind_resumption_eq_Done_iff [simp]:
  "f  g = Done x  (y. f = DONE y  g y = Done x)  f = ABORT  x = None"
by(cases f)(auto simp add: Done_bind split: option.split)

(* Congruence rule for bind: the two continuations f and g only have to agree
   on the values in results y.  Proved by strong coinduction after replacing
   x by y. *)
lemma bind_resumption_cong:
  assumes "x = y"
  and "z. z  results y  f z = g z"
  shows "x  f = y  g"
using assms(2) unfolding x = y
proof(coinduction arbitrary: y rule: resumption.coinduct_strong)
  case Eq_resumption thus ?case
    by(auto intro: resumption.set_sel simp add: is_Done_def rel_fun_def)
      (fastforce del: exI intro!: exI intro: resumption.set_sel(2) simp add: is_Done_def)
qed

(* The results of a bind are the results of f a, collected over the results a
   of x.  Both inclusions are shown by induction on the membership
   derivation of the respective results set. *)
lemma results_bind_resumption: (* Move to Resumption *)
  "results (bind_resumption x f) = (aresults x. results (f a))"
  (is "?lhs = ?rhs")
proof(intro set_eqI iffI)
  (* Direction lhs to rhs: induction with the bind abstracted as r. *)
  show "z  ?rhs" if "z  ?lhs" for z using that
  proof(induction r"x  f" arbitrary: x)
    case (Done z z' x)
    from Done(1) Done(2)[symmetric] show ?case by(auto)
  next
    case (Pause out c r z x)
    then show ?case
    proof(cases x)
      (* x is Done: distinguish whether it carries a result. *)
      case (Done x')
      show ?thesis
      proof(cases x')
        case None
        with Done Pause(4) show ?thesis by(auto simp add: ABORT_def[symmetric])
      next
        case (Some x'')
        thus ?thesis using Pause(1,2,4) Done
          by(auto 4 3 simp add: DONE_def[unfolded o_def, symmetric, unfolded fun_eq_iff] dest: sym)
      qed
    qed(fastforce)
  qed
next
  (* Direction rhs to lhs: pick the intermediate result z' of x and induct on
     its membership in results x. *)
  fix z 
  assume "z  ?rhs"
  then obtain z' where z': "z'  results x"
    and z: "z  results (f z')" by blast
  from z' show "z  ?lhs"
  proof(induction z'z' x)
    case (Done r)
    then show ?case using z
      by(auto simp add: DONE_def[unfolded o_def, symmetric, unfolded fun_eq_iff])
  qed auto
qed

(* The outputs of a bind consist of the outputs of r together with the
   outputs of f x for the results x of r.  Declared as a simp rule. *)
lemma outputs_bind_resumption [simp]:
  "outputs (bind_resumption r f) = outputs r  (xresults r. outputs (f x))"
  (is "?lhs = ?rhs")
proof(rule set_eqI iffI)+
  (* lhs to rhs: induction on the outputs derivation of the bind. *)
  show "x  ?rhs" if "x  ?lhs" for x using that
  proof(induction r'"bind_resumption r f" arbitrary: r)
    case (Pause1 out c)
    thus ?case by(cases r)(auto simp add: Done_bind split: option.split_asm dest: sym)
  next
    case (Pause2 out c r' x)
    thus ?case by(cases r)(auto 4 3 simp add: Done_bind split: option.split_asm dest: sym)
  qed
next
  (* rhs to lhs: x is either an output of r (left) or an output of f a for a
     result a of r (right). *)
  fix x
  assume "x  ?rhs"
  then consider (left) "x  outputs r" | (right) a where "a  results r" "x  outputs (f a)" by auto
  then show "x  ?lhs"
  proof cases
    { case left  thus ?thesis by induction auto }
    { case right thus ?thesis by induction(auto simp add: Done_bind) }
  qed
qed

(* ensure b finishes with the unit value when b is True and aborts when b is
   False. *)
primrec ensure :: "bool  (unit, 'out, 'in) resumption"
where
  "ensure True = DONE ()" 
| "ensure False = ABORT"

(* Mapping over a resumption relates is_Done of the mapped resumption to
   is_Done of the original; proved by case distinction on r. *)
lemma is_Done_map_resumption [simp]:
  "is_Done (map_resumption f1 f2 r)  is_Done r"
by(cases r) simp_all

(* For a Done resumption, mapping applies f1 to the optional result. *)
lemma result_map_resumption [simp]: 
  "is_Done r  result (map_resumption f1 f2 r) = map_option f1 (result r)"
by(clarsimp simp add: is_Done_def)

(* For a resumption that is not Done, mapping applies f2 to the output. *)
lemma output_map_resumption [simp]:
  "¬ is_Done r  output (map_resumption f1 f2 r) = f2 (output r)"
by(cases r) simp_all

(* For a resumption that is not Done, resuming the mapped resumption is
   expressed by mapping with the same f1, f2 over what resume r returns. *)
lemma resume_map_resumption [simp]:
  "¬ is_Done r
   resume (map_resumption f1 f2 r) = map_resumption f1 f2  resume r"
by(cases r) simp_all

(* Destruction rule: related resumptions agree on is_Done.  Proved by a
   simultaneous case distinction on both resumptions. *)
lemma rel_resumption_is_DoneD: "rel_resumption A B r1 r2  is_Done r1  is_Done r2"
by(cases r1 r2 rule: resumption.exhaust[case_product resumption.exhaust]) simp_all

(* Destruction rule: if related resumptions are Done (premise on r1), their
   results are related by rel_option A. *)
lemma rel_resumption_resultD1:
  " rel_resumption A B r1 r2; is_Done r1   rel_option A (result r1) (result r2)"
by(cases r1 r2 rule: resumption.exhaust[case_product resumption.exhaust]) simp_all

(* Same as rel_resumption_resultD1, but with the is_Done premise on r2. *)
lemma rel_resumption_resultD2:
  " rel_resumption A B r1 r2; is_Done r2   rel_option A (result r1) (result r2)"
by(cases r1 r2 rule: resumption.exhaust[case_product resumption.exhaust]) simp_all

(* Destruction rule: if related resumptions are not Done (premise on r1),
   their outputs are related by B. *)
lemma rel_resumption_outputD1:
  " rel_resumption A B r1 r2; ¬ is_Done r1   B (output r1) (output r2)"
by(cases r1 r2 rule: resumption.exhaust[case_product resumption.exhaust]) simp_all

(* Same as rel_resumption_outputD1, but with the premise on r2. *)
lemma rel_resumption_outputD2:
  " rel_resumption A B r1 r2; ¬ is_Done r2   B (output r1) (output r2)"
by(cases r1 r2 rule: resumption.exhaust[case_product resumption.exhaust]) simp_all

(* Destruction rule: if related resumptions are not Done (premise on r1),
   resuming both on the same input inp yields related resumptions. *)
lemma rel_resumption_resumeD1:
  " rel_resumption A B r1 r2; ¬ is_Done r1 
   rel_resumption A B (resume r1 inp) (resume r2 inp)"
by(cases r1 r2 rule: resumption.exhaust[case_product resumption.exhaust])(auto dest: rel_funD)

(* Same as rel_resumption_resumeD1, but with the premise on r2. *)
lemma rel_resumption_resumeD2:
  " rel_resumption A B r1 r2; ¬ is_Done r2 
   rel_resumption A B (resume r1 inp) (resume r2 inp)"
by(cases r1 r2 rule: resumption.exhaust[case_product resumption.exhaust])(auto dest: rel_funD)

(* Selector-style coinduction rule for rel_resumption, registered as the
   coinduct rule for that predicate.  For a candidate relation X one has to
   show (case Done) agreement on is_Done and related results, and (case
   Pause) related outputs and X again after resuming on any input.
   Derived from the generated rule resumption.rel_coinduct. *)
lemma rel_resumption_coinduct
  [consumes 1, case_names Done Pause,
   case_conclusion Done is_Done result,
   case_conclusion Pause "output" resume,
   coinduct pred: rel_resumption]:
  assumes X: "X r1 r2"
  and Done: "r1 r2. X r1 r2  (is_Done r1  is_Done r2)  (is_Done r1  is_Done r2  rel_option A (result r1) (result r2))"
  and Pause: "r1 r2.  X r1 r2; ¬ is_Done r1; ¬ is_Done r2   B (output r1) (output r2)  (inp. X (resume r1 inp) (resume r2 inp))" 
  shows "rel_resumption A B r1 r2"
using X
apply(rule resumption.rel_coinduct)
apply(unfold rel_fun_def)
apply(rule conjI)
 apply(erule Done[THEN conjunct1])
apply(rule conjI)
 apply(erule Done[THEN conjunct2])
apply(rule impI)+
apply(drule (2) Pause)
apply blast
done

subsection ‹Setup for partial_function›

(* Definition of the approximation order resumption_ord together with its
   basic destruction rules and a selector-style coinduction rule.  The
   lifting_syntax bundle is included for the ===> notation used below. *)
context includes lifting_syntax begin

(* Coinductive order on resumptions with three rules:
   Done_Done   - two finished resumptions whose results are related by
                 flat_ord None;
   Done_Pause  - ABORT is below every paused resumption;
   Pause_Pause - two paused resumptions with the same output whose
                 continuations are related for equal inputs. *)
coinductive resumption_ord :: "('a, 'out, 'in) resumption  ('a, 'out, 'in) resumption  bool"
where
  Done_Done: "flat_ord None a a'  resumption_ord (Done a) (Done a')"
| Done_Pause: "resumption_ord ABORT (Pause out c)"
| Pause_Pause: "((=) ===> resumption_ord) c c'  resumption_ord (Pause out c) (Pause out c')"

(* Generated simp rules for a Pause on the left and a Done on the right. *)
inductive_simps resumption_ord_simps [simp]:
  "resumption_ord (Pause out c) r"
  "resumption_ord r (Done a)"

(* If the upper resumption is Done, so is the lower one. *)
lemma resumption_ord_is_DoneD:
  " resumption_ord r r'; is_Done r'   is_Done r"
by(cases r')(auto simp add: fun_ord_def)

(* If the upper resumption is Done, the results are related by flat_ord None. *)
lemma resumption_ord_resultD:
  " resumption_ord r r'; is_Done r'   flat_ord None (result r) (result r')"
by(cases r')(auto simp add: flat_ord_def)

(* If the lower resumption is not Done, both have the same output. *)
lemma resumption_ord_outputD:
  " resumption_ord r r'; ¬ is_Done r   output r = output r'"
by(cases r) auto

(* If the lower resumption is not Done, the continuations are related. *)
lemma resumption_ord_resumeD:
  " resumption_ord r r'; ¬ is_Done r   ((=) ===> resumption_ord) (resume r) (resume r')"
by(cases r) auto

(* A Done resumption below one that is not Done has no result. *)
lemma resumption_ord_abort:
  " resumption_ord r r'; is_Done r; ¬ is_Done r'   result r = None"
by(auto elim: resumption_ord.cases)

(* Selector-style coinduction rule for resumption_ord with cases Done, Abort
   and Pause; in the Pause case the continuations may be related by X or by
   resumption_ord itself. *)
lemma resumption_ord_coinduct [consumes 1, case_names Done Abort Pause, case_conclusion Pause "output" resume, coinduct pred: resumption_ord]:
  assumes "X r r'"
  and Done: "r r'.  X r r'; is_Done r'   is_Done r  flat_ord None (result r) (result r')"
  and Abort: "r r'.  X r r'; ¬ is_Done r'; is_Done r   result r = None"
  and Pause: "r r'.  X r r'; ¬ is_Done r; ¬ is_Done r'  
   output r = output r'  ((=) ===> (λr r'. X r r'  resumption_ord r r')) (resume r) (resume r')"
  shows "resumption_ord r r'"
using X r r'
proof coinduct
  case (resumption_ord r r')
  thus ?case
    by(cases r r' rule: resumption.exhaust[case_product resumption.exhaust])(auto dest: Done Pause Abort)
qed

end

(* ABORT is below every resumption, i.e. it is the least element of
   resumption_ord. *)
lemma resumption_ord_ABORT [intro!, simp]: "resumption_ord ABORT r"
by(cases r)(simp_all add: flat_ord_def resumption_ord.Done_Pause)

(* Simp rule relating "r is below ABORT" to r = ABORT. *)
lemma resumption_ord_ABORT2 [simp]: "resumption_ord r ABORT  r = ABORT"
by(simp add: ABORT_def flat_ord_def)

(* Simp rule relating "DONE x is below r" to r = DONE x. *)
lemma resumption_ord_DONE1 [simp]: "resumption_ord (DONE x) r  r = DONE x"
by(cases r)(auto simp add: option_ord_Some1_iff DONE_def dest: resumption_ord_abort)

(* Reflexivity of resumption_ord, by coinduction. *)
lemma resumption_ord_refl: "resumption_ord r r"
by(coinduction arbitrary: r)(auto simp add: flat_ord_def)

(* Antisymmetry of resumption_ord: resumptions below each other are equal.
   Proved by strong coinduction on equality of resumptions. *)
lemma resumption_ord_antisym:
  " resumption_ord r r'; resumption_ord r' r 
   r = r'"
proof(coinduction arbitrary: r r' rule: resumption.coinduct_strong)
  case (Eq_resumption r r')
  thus ?case
    by cases(auto simp add: flat_ord_def rel_fun_def)
qed

(* Transitivity of resumption_ord, by coinduction with the three cases of
   resumption_ord_coinduct. *)
lemma resumption_ord_trans:
  " resumption_ord r r'; resumption_ord r' r'' 
   resumption_ord r r''"
proof(coinduction arbitrary: r r' r'')
  case (Done r r' r'')
  thus ?case by(auto 4 4 elim: resumption_ord.cases simp add: flat_ord_def)
next
  case (Abort r r' r'')
  thus ?case by(auto 4 4 elim: resumption_ord.cases simp add: flat_ord_def)
next
  case (Pause r r' r'')
  hence "resumption_ord r r'" "resumption_ord r' r''" by simp_all
  thus ?case using ¬ is_Done r ¬ is_Done r''
    by(cases)(auto simp add: rel_fun_def)
qed

(* Least upper bound of a set R of resumptions, defined corecursively via the
   selectors:
   - the first clause ties is_Done of the lub to is_Done of the elements of R
     (cf. resumption_lub_empty: the lub of the empty set is ABORT);
   - the result is the flat lub of the results of the elements;
   - the output is THE output among the elements that are not Done;
   - resuming on inp takes the lub of the continuations of the elements that
     are not Done, each applied to inp. *)
primcorec resumption_lub :: "('a, 'out, 'in) resumption set  ('a, 'out, 'in) resumption"
where
  "r  R. is_Done r  is_Done (resumption_lub R)"
| "result (resumption_lub R) = flat_lub None (result ` R)"
| "output (resumption_lub R) = (THE out. out  output ` (R  {r. ¬ is_Done r}))"
| "resume (resumption_lub R) = (λinp. resumption_lub ((λc. c inp) ` resume ` (R  {r. ¬ is_Done r})))"

(* Simp rule characterising is_Done of the lub through is_Done of the
   elements of R. *)
lemma is_Done_resumption_lub [simp]:
  "is_Done (resumption_lub R)  (r  R. is_Done r)"
by(simp add: resumption_lub_def)

(* Result selector of the lub, under a premise about is_Done on the elements
   of R: it is the flat lub of the results of the elements. *)
lemma result_resumption_lub [simp]:
  "r  R. is_Done r  result (resumption_lub R) = flat_lub None (result ` R)"
by(simp add: resumption_lub_def)

(* Output selector of the lub, under a premise about an element of R that is
   not Done: THE output among the elements that are not Done. *)
lemma output_resumption_lub [simp]:
  "rR. ¬ is_Done r  output (resumption_lub R) = (THE out. out  output ` (R  {r. ¬ is_Done r}))"
by(simp add: resumption_lub_def)

(* Resume selector of the lub, under a premise about an element of R that is
   not Done: the lub of the continuations of those elements applied to inp. *)
lemma resume_resumption_lub [simp]:
  "rR. ¬ is_Done r
   resume (resumption_lub R) inp = 
     resumption_lub ((λc. c inp) ` resume ` (R  {r. ¬ is_Done r}))"
by(simp add: resumption_lub_def)

(* The lub of the empty set is ABORT; used as the rewrite in the
   interpretation of partial_function_definitions. *)
lemma resumption_lub_empty: "resumption_lub {} = ABORT"
by(subst resumption_lub.code)(simp add: flat_lub_def)

(* Local context for a chain R of resumptions and a fixed input inp.  R' is
   the set obtained by resuming every element of R that is not Done on inp.
   Only R, inp and R' are fixed: nothing in the definition, the assumption or
   the lemma mentions any other variable. *)
context
  fixes R inp R'
  defines R'_def: "R'  (λc. c inp) ` resume ` (R  {r. ¬ is_Done r})"
  assumes chain: "Complete_Partial_Order.chain resumption_ord R"
begin

(* Resuming all elements of a chain that are not Done on the same input
   again gives a chain. *)
lemma resumption_ord_chain_resume: "Complete_Partial_Order.chain resumption_ord R'"
proof(rule chainI)
  fix r' r''
  assume "r'  R'"
    and "r''  R'"
  (* Pick the elements of R that r' and r'' were obtained from. *)
  then obtain 𝗋' 𝗋'' 
    where r': "r' = resume 𝗋' inp" "𝗋'  R" "¬ is_Done 𝗋'"
    and r'': "r'' = resume 𝗋'' inp" "𝗋''  R" "¬ is_Done 𝗋''"
    by(auto simp add: R'_def)
  (* They are comparable because R is a chain ... *)
  from chain 𝗋'  R 𝗋''  R
  have "resumption_ord 𝗋' 𝗋''  resumption_ord 𝗋'' 𝗋'"
    by(auto elim: chainE)
  (* ... and since neither is Done, so are their continuations at inp. *)
  with r' r''
  have "resumption_ord (resume 𝗋' inp) (resume 𝗋'' inp) 
        resumption_ord (resume 𝗋'' inp) (resume 𝗋' inp)"
    by(auto elim: resumption_ord.cases simp add: rel_fun_def)
  with r' r''
  show "resumption_ord r' r''  resumption_ord r'' r'" by auto
qed

end

(* resumption_ord and resumption_lub satisfy partial_function_definitions:
   the order is reflexive, transitive and antisymmetric, and resumption_lub
   is an upper bound and the least upper bound of every chain.  The two lub
   properties are each shown by coinduction with cases Done, Abort, Pause. *)
lemma resumption_partial_function_definition:
  "partial_function_definitions resumption_ord resumption_lub"
proof
  (* Order axioms, from the lemmas proved above. *)
  show "resumption_ord r r" for r :: "('a, 'b, 'c) resumption" by(rule resumption_ord_refl)
  show "resumption_ord r r''" if "resumption_ord r r'" "resumption_ord r' r''"
    for r r' r'' :: "('a, 'b, 'c) resumption" using that by(rule resumption_ord_trans)
  show "r = r'" if "resumption_ord r r'" "resumption_ord r' r" for r r' :: "('a, 'b, 'c) resumption"
    using that by(rule resumption_ord_antisym)
next
  (* Upper bound: every element of a chain is below the lub. *)
  fix R and r :: "('a, 'b, 'c) resumption"
  assume "Complete_Partial_Order.chain resumption_ord R" "r  R"
  thus "resumption_ord r (resumption_lub R)"
  proof(coinduction arbitrary: r R)
    case (Done r R)
    note chain = ‹Complete_Partial_Order.chain resumption_ord R
      and r = r  R
    from ‹is_Done (resumption_lub R) have A: "r  R. is_Done r" by simp
    with r obtain a' where "r = Done a'" by(cases r) auto
    (* If r carries a result a', THE in flat_lub picks exactly a'. *)
    { fix r'
      assume "a'  None"
      hence "(THE x. x  result ` R  x  None) = a'"
        using r A r = Done a'
        by(auto 4 3 del: the_equality intro!: the_equality intro: rev_image_eqI elim: chainE[OF chain] simp add: flat_ord_def is_Done_def) 
    }
    with A r r = Done a' show ?case
      by(cases a')(auto simp add: flat_ord_def flat_lub_def)
  next
    case (Abort r R)
    hence chain: "Complete_Partial_Order.chain resumption_ord R" and "r  R" by simp_all
    from r  R ¬ is_Done (resumption_lub R) ‹is_Done r
    show ?case by(auto elim: chainE[OF chain] dest: resumption_ord_abort resumption_ord_is_DoneD)
  next
    case (Pause r R)
    hence chain: "Complete_Partial_Order.chain resumption_ord R"
      and r: "r  R" by simp_all
    have ?resume 
      using r ¬ is_Done r resumption_ord_chain_resume[OF chain]
      by(auto simp add: rel_fun_def bexI)
    moreover
    from r ¬ is_Done r have "output (resumption_lub R) = output r"
      by(auto 4 4 simp add: bexI del: the_equality intro!: the_equality elim: chainE[OF chain] dest: resumption_ord_outputD)
    ultimately show ?case by simp
  qed
next
  (* Least upper bound: the lub is below every upper bound r of the chain. *)
  fix R and r :: "('a, 'b, 'c) resumption"
  assume "Complete_Partial_Order.chain resumption_ord R" "r'. r'  R  resumption_ord r' r"
  thus "resumption_ord (resumption_lub R) r"
  proof(coinduction arbitrary: R r)
    case (Done R r)
    hence chain: "Complete_Partial_Order.chain resumption_ord R"
      and ub: "r'R. resumption_ord r' r" by simp_all
    from ‹is_Done r ub have is_Done: "r'  R. is_Done r'"
      and ub': "r'. r'  result ` R  flat_ord None r' (result r)"
      by(auto dest: resumption_ord_is_DoneD resumption_ord_resultD)
    (* The results form a chain in the flat order, so lub_least of the flat
       ccpo applies. *)
    from is_Done have chain': "Complete_Partial_Order.chain (flat_ord None) (result ` R)"
      by(auto 5 2 intro!: chainI elim: chainE[OF chain] dest: resumption_ord_resultD)
    hence "flat_ord None (flat_lub None (result ` R)) (result r)"
      by(rule partial_function_definitions.lub_least[OF flat_interpretation])(rule ub')
    thus ?case using is_Done by simp
  next
    case (Abort R r)
    hence chain: "Complete_Partial_Order.chain resumption_ord R"
      and ub: "r'R. resumption_ord r' r" by simp_all
    from ¬ is_Done r ‹is_Done (resumption_lub R) ub
    show ?case by(auto simp add: flat_lub_def dest: resumption_ord_abort)
  next
    case (Pause R r)
    hence chain: "Complete_Partial_Order.chain resumption_ord R"
      and ub: "r'. r'R  resumption_ord r' r" by simp_all
    from ¬ is_Done (resumption_lub R) have exR: "r  R. ¬ is_Done r" by simp
    then obtain r' where r': "r'  R" "¬ is_Done r'" by auto
    with ub[of r'] have "output r = output r'" by(auto dest: resumption_ord_outputD)
    also have [symmetric]: "output (resumption_lub R) = output r'" using exR r'
      by(auto 4 4 elim: chainE[OF chain] dest: resumption_ord_outputD)
    finally have ?output ..
    moreover 
    { fix inp r''
      assume "r''  R" "¬ is_Done r''"
      with ub[of r'']
      have "resumption_ord (resume r'' inp) (resume r inp)"
        by(auto dest!: resumption_ord_resumeD simp add: rel_fun_def) }
    with exR resumption_ord_chain_resume[OF chain] r'
    have ?resume by(auto simp add: rel_fun_def)
    ultimately show ?case ..
  qed
qed

(* Interpret the partial_function_definitions locale for resumptions under
   the prefix "resumption", rewriting the lub of the empty set to ABORT. *)
interpretation resumption:
  partial_function_definitions resumption_ord resumption_lub
  rewrites "resumption_lub {} = (ABORT :: ('a, 'b, 'c) resumption)"
by (rule resumption_partial_function_definition resumption_lub_empty)+

(* Register the mode "resumption" with the partial_function package, passing
   the fixpoint combinator, the monotonicity predicate and the unfold and
   induction rules obtained from the interpretation above. *)
declaration Partial_Function.init "resumption" @{term resumption.fixp_fun}
  @{term resumption.mono_body} @{thm resumption.fixp_rule_uc} @{thm resumption.fixp_induct_uc} NONE›

(* Abbreviation: monotonicity from the pointwise order on functions into
   resumptions to resumption_ord. *)
abbreviation "mono_resumption  monotone (fun_ord resumption_ord) resumption_ord"

(* If B is monotone, then so is resuming B f on a fixed input inp. *)
lemma mono_resumption_resume:
  assumes "mono_resumption B"
  shows "mono_resumption (λf. resume (B f) inp)"
proof
  fix f g :: "'a  ('b, 'c, 'd) resumption"
  assume fg: "fun_ord resumption_ord f g"
  hence "resumption_ord (B f) (B g)" by(rule monotoneD[OF assms])
  (* Case split on whether B f is Done. *)
  with resumption_ord_resumeD[OF this]
  show "resumption_ord (resume (B f) inp) (resume (B g) inp)"
    by(cases "is_Done (B f)")(auto simp add: rel_fun_def is_Done_def)
qed

(* Monotonicity rule for bind, registered with partial_function: if B and
   every C y are monotone, then so is binding C y onto B.  After naming the
   four resumption-valued terms, the statement is shown by coinduction with
   cases Done, Abort and Pause. *)
lemma bind_resumption_mono [partial_function_mono]:
  assumes mf: "mono_resumption B"
  and mg: "y. mono_resumption (C y)"
  shows "mono_resumption (λf. do { y  B f; C y f })"
proof(rule monotoneI)
  fix f g :: "'a  ('b, 'c, 'd) resumption"
  assume *: "fun_ord resumption_ord f g"
  (* Abstract from B, C, f, g so that the coinduction can generalise. *)
  define f' where "f'  B f" define g' where "g'  B g"
  define h where "h  λx. C x f" define k where "k  λx. C x g"
  from mf[THEN monotoneD, OF *] mg[THEN monotoneD, OF *] f'_def g'_def h_def k_def
  have "resumption_ord f' g'" "x. resumption_ord (h x) (k x)" by auto
  thus "resumption_ord (f'  h) (g'  k)"
  proof(coinduction arbitrary: f' g' h k)
    case (Done f' g' h k)
    hence le: "resumption_ord f' g'"
      and mg: "y. resumption_ord (h y) (k y)" by simp_all
    from ‹is_Done (g'  k)
    have done_Bg: "is_Done g'" 
      and "result g'  None  is_Done (k (the (result g')))" by simp_all
    moreover
    have "is_Done f'" using le done_Bg by(rule resumption_ord_is_DoneD)
    moreover
    from le done_Bg have "flat_ord None (result f') (result g')"
      by(rule resumption_ord_resultD)
    hence "result f'  None  result g' = result f'"
      by(auto simp add: flat_ord_def)
    moreover
    have "resumption_ord (h (the (result f'))) (k (the (result f')))" by(rule mg)
    ultimately show ?case
      by(subst (1 2) result_bind_resumption)(auto dest: resumption_ord_is_DoneD resumption_ord_resultD simp add: flat_ord_def bind_eq_None_conv)
  next
    case (Abort f' g' h k)
    hence "resumption_ord (h (the (result f'))) (k (the (result f')))" by simp
    thus ?case using Abort
      by(cases "is_Done g'")(auto 4 4 simp add: bind_eq_None_conv flat_ord_def dest: resumption_ord_abort resumption_ord_resultD resumption_ord_is_DoneD)
  next
    case (Pause f' g' h k)
    hence ?output
      by(auto 4 4 dest: resumption_ord_outputD resumption_ord_is_DoneD resumption_ord_resultD resumption_ord_abort simp add: flat_ord_def)
    (* The continuation part depends on whether f' is already Done. *)
    moreover have ?resume
    proof(cases "is_Done f'")
      case False
      with Pause show ?thesis
        by(auto simp add: rel_fun_def dest: resumption_ord_is_DoneD intro: resumption_ord_resumeD[THEN rel_funD] del: exI intro!: exI)
    next
      case True
      hence "is_Done g'" using Pause by(auto dest: resumption_ord_abort)
      thus ?thesis using True Pause resumption_ord_resultD[OF ‹resumption_ord f' g']
        by(auto del: rel_funI intro!: rel_funI simp add: bind_resumption_is_Done flat_ord_def intro: resumption_ord_resumeD[THEN rel_funD] exI[where x=f'] exI[where x=g'])
    qed
    ultimately show ?case ..
  qed
qed

(* Characterises results as a fixpoint (w.r.t. set inclusion, lifted to
   functions) of the functional F, which returns the optional result of a
   Done and collects the recursive call over all inputs for a Pause; also
   shows that F is monotone in its function argument. *)
lemma fixes f F
  defines "F  λresults r. case r of resumption.Done x  set_option x | resumption.Pause out c  input. results (c input)"
  shows results_conv_fixp: "results  ccpo.fixp (fun_lub Union) (fun_ord (⊆)) F" (is "_  ?fixp")
  and results_mono: "x. monotone (fun_ord (⊆)) (⊆) (λf. F f x)" (is "PROP ?mono")
proof(rule eq_reflection ext antisym subsetI)+
  show mono: "PROP ?mono" unfolding F_def by(tactic Partial_Function.mono_tac @{context} 1)
  (* Fixpoint is contained in results: fixpoint induction. *)
  fix x r
  show "?fixp r  results r"
    by(induction arbitrary: r rule: lfp.fixp_induct_uc[of "λx. x" F "λx. x", OF mono reflexive refl])
      (fastforce simp add: F_def split: resumption.split_asm)+

  (* results is contained in the fixpoint: induction on membership, unfolding
     the fixpoint once per step. *)
  assume "x  results r"
  thus "x  ?fixp r" by induct(subst lfp.mono_body_fixp[OF mono]; auto simp add: F_def)+
qed

(* Continuity of a case distinction over resumptions: if the Done branch f is
   mcont on the flat option order, every Pause branch is mcont in the
   continuation, the target (lub, ord) is a ccpo, and f None is a least
   element, then the case expression is mcont w.r.t. resumption_ord.
   The proof goes through resumption.mcont_if_bot with bound ABORT and the
   selector-based function h. *)
lemma mcont_case_resumption:
  fixes f g
  defines "h  λr. if is_Done r then f (result r) else g (output r) (resume r) r"
  assumes mcont1: "mcont (flat_lub None) option_ord lub ord f"
  and mcont2: "out. mcont (fun_lub resumption_lub) (fun_ord resumption_ord) lub ord (λc. g out c (Pause out c))"
  and ccpo: "class.ccpo lub ord (mk_less ord)"
  and bot: "x. ord (f None) x"
  shows "mcont resumption_lub resumption_ord lub ord (λr. case r of Done x  f x | Pause out c  g out c r)"
    (is "mcont ?lub ?ord _ _ ?f")
proof(rule resumption.mcont_if_bot[OF ccpo bot, where bound=ABORT and f=h])
  show "?f x = (if ?ord x ABORT then f None else h x)" for x
    by(simp add: h_def split: resumption.split)
  show "ord (h x) (h y)" if "?ord x y" "¬ ?ord x ABORT" for x y using that
    by(cases x)(simp_all add: h_def mcont_monoD[OF mcont1] fun_ord_conv_rel_fun mcont_monoD[OF mcont2])
    
  (* h commutes with lubs of non-empty chains that avoid ABORT. *)
  fix Y :: "('a, 'b, 'c) resumption set"
  assume chain: "Complete_Partial_Order.chain ?ord Y"
    and Y: "Y  {}"
    and nbot: "x. x  Y  ¬ ?ord x ABORT"
  show "h (?lub Y) = lub (h ` Y)"
  proof(cases "x. DONE x  Y")
    (* Some DONE x is in the chain: all elements are Done; use mcont1. *)
    case True
    then obtain x where x: "DONE x  Y" ..
    have is_Done: "is_Done r" if "r  Y" for r using chainD[OF chain that x]
      by(auto dest: resumption_ord_is_DoneD)
    from is_Done have chain': "Complete_Partial_Order.chain (flat_ord None) (result ` Y)"
      by(auto 5 2 intro!: chainI elim: chainE[OF chain] dest: resumption_ord_resultD)
    from is_Done have "is_Done (?lub Y)" "Y  {r. is_Done r} = Y" "Y  {r. ¬ is_Done r} = {}" by auto
    then show ?thesis using Y by(simp add: h_def mcont_contD[OF mcont1 chain'] image_image)
  next
    (* No DONE x in the chain: no element is Done, all share one output;
       use mcont2 on the chain of continuations. *)
    case False
    have is_Done: "¬ is_Done r" if "r  Y" for r using that False nbot
      by(auto elim!: is_Done_cases)
    from Y obtain out c where Pause: "Pause out c  Y"
      by(auto 5 2 dest: is_Done iff: not_is_Done_conv_Pause)
    
    have out: "(THE out. out  output ` (Y  {r. ¬ is_Done r})) = out" using Pause
      by(auto 4 3 intro: rev_image_eqI iff: not_is_Done_conv_Pause dest: chainD[OF chain])
    have "(λr. g (output r) (resume r) r) ` (Y  {r. ¬ is_Done r}) = (λr. g out (resume r) r) ` (Y  {r. ¬ is_Done r})"
      by(auto 4 3 simp add: not_is_Done_conv_Pause dest: chainD[OF chain Pause] intro: rev_image_eqI)
    moreover have "¬ is_Done (?lub Y)" using Y is_Done by(auto)
    moreover from is_Done have "Y  {r. is_Done r} = {}" "Y  {r. ¬ is_Done r} = Y" by auto
    moreover have "(λinp. resumption_lub ((λx. resume x inp) ` Y)) = fun_lub resumption_lub (resume ` Y)"
      by(auto simp add: fun_lub_def fun_eq_iff intro!: arg_cong[where f="resumption_lub"])
    moreover have "resumption_lub Y = Pause out (fun_lub resumption_lub (resume ` Y))"
      using Y is_Done out
      by(intro resumption.expand)(auto simp add: fun_lub_def fun_eq_iff image_image intro!: arg_cong[where f=resumption_lub])
    moreover have chain': "Complete_Partial_Order.chain resumption.le_fun (resume ` Y)" using chain
      by(rule chain_imageI)(auto dest!: is_Done simp add: not_is_Done_conv_Pause fun_ord_conv_rel_fun)
    moreover have "(λr. g out (resume r) (Pause out (resume r))) ` Y = (λr. g out (resume r) r) ` Y"
      by(intro image_cong[OF refl])(frule nbot; auto dest!: chainD[OF chain Pause] elim: resumption_ord.cases)
    ultimately show ?thesis using False out Y 
      by(simp add: h_def image_image mcont_contD[OF mcont2])
  qed
qed
    
(* results is mcont from resumption_ord to set inclusion.  The fact is stored
   as mcont_results; the composed form (via THEN mcont2mcont) is registered
   as a cont_intro and simp rule under the lemma name. *)
lemma mcont2mcont_results[THEN mcont2mcont, cont_intro, simp]:
  shows mcont_results: "mcont resumption_lub resumption_ord Union (⊆) results"
apply(rule lfp.fixp_preserves_mcont1[OF results_mono results_conv_fixp])
apply(rule mcont_case_resumption)
apply(simp_all add: mcont_applyI)
done

(* results is monotone from resumption_ord to set inclusion; obtained from
   mcont_results. *)
lemma mono2mono_results[THEN lfp.mono2mono, cont_intro, simp]:
  shows monotone_results: "monotone resumption_ord (⊆) results"
using mcont_results by(rule mcont_mono)

(* Characterises outputs as a fixpoint (w.r.t. set inclusion, lifted to
   functions) of the functional F, which is empty on Done and inserts the
   output and collects the recursive call over all inputs on Pause; also
   shows that F is monotone in its function argument. *)
lemma fixes f F
  defines "F  λoutputs xs. case xs of resumption.Done x  {} | resumption.Pause out c  insert out (input. outputs (c input))"
  shows outputs_conv_fixp: "outputs  ccpo.fixp (fun_lub Union) (fun_ord (⊆)) F" (is "_  ?fixp")
  and outputs_mono: "x. monotone (fun_ord (⊆)) (⊆) (λf. F f x)" (is "PROP ?mono")
proof(rule eq_reflection ext antisym subsetI)+
  show mono: "PROP ?mono" unfolding F_def by(tactic Partial_Function.mono_tac @{context} 1)
  (* Fixpoint is contained in outputs: fixpoint induction. *)
  show "?fixp r  outputs r" for r
    by(induct arbitrary: r rule: lfp.fixp_induct_uc[of "λx. x" F "λx. x", OF mono reflexive refl])(auto simp add: F_def split: resumption.split)
  (* outputs is contained in the fixpoint: induction on membership. *)
  show "x  ?fixp r" if "x  outputs r" for x r using that
    by induct(subst lfp.mono_body_fixp[OF mono]; auto simp add: F_def; fail)+
qed

(* outputs is mcont from resumption_ord to set inclusion.  The fact is stored
   as mcont_outputs; the composed form is registered as a cont_intro and simp
   rule under the lemma name. *)
lemma mcont2mcont_outputs[THEN lfp.mcont2mcont, cont_intro, simp]: 
  shows mcont_outputs: "mcont resumption_lub resumption_ord Union (⊆) outputs"
apply(rule lfp.fixp_preserves_mcont1[OF outputs_mono outputs_conv_fixp])
apply(auto intro: lfp.mcont2mcont intro!: mcont2mcont_insert mcont_SUP mcont_case_resumption)
done

(* outputs is monotone from resumption_ord to set inclusion; obtained from
   mcont_outputs. *)
lemma mono2mono_outputs[THEN lfp.mono2mono, cont_intro, simp]:
  shows monotone_outputs: "monotone resumption_ord (⊆) outputs"
using mcont_outputs by(rule mcont_mono)

(* pred_resumption is antitone in resumption_ord: if it holds for r' and r is
   below r', it holds for r.  Uses monotonicity of results and outputs. *)
lemma pred_resumption_antimono:
  assumes r: "pred_resumption A C r'"
  and le: "resumption_ord r r'"
  shows "pred_resumption A C r"
using r monotoneD[OF monotone_results le] monotoneD[OF monotone_outputs le]
by(auto simp add: pred_resumption_def)

subsection ‹Setup for lifting and transfer›

(* Register the generated relator facts of resumption with the transfer
   infrastructure: rel_eq as id_simps / relator_eq, rel_mono as relator_mono. *)
declare resumption.rel_eq [id_simps, relator_eq]
declare resumption.rel_mono [relator_mono]

(* rel_resumption distributes over relation composition; registered as a
   relator_distr rule. *)
lemma rel_resumption_OO [relator_distr]:
  "rel_resumption A B OO rel_resumption C D = rel_resumption (A OO C) (B OO D)" 
by(simp add: resumption.rel_compp)

(* rel_resumption preserves left-totality of its argument relations. *)
lemma left_total_rel_resumption [transfer_rule]:
  " left_total R1; left_total R2   left_total (rel_resumption R1 R2)"
  by(simp only: left_total_alt_def resumption.rel_eq[symmetric] resumption.rel_conversep[symmetric] rel_resumption_OO resumption.rel_mono)

(* rel_resumption preserves left-uniqueness of its argument relations. *)
lemma left_unique_rel_resumption [transfer_rule]:
  " left_unique R1; left_unique R2   left_unique (rel_resumption R1 R2)"
  by(simp only: left_unique_alt_def resumption.rel_eq[symmetric] resumption.rel_conversep[symmetric] rel_resumption_OO resumption.rel_mono)

(* rel_resumption preserves right-totality of its argument relations. *)
lemma right_total_rel_resumption [transfer_rule]:
  " right_total R1; right_total R2   right_total (rel_resumption R1 R2)"
  by(simp only: right_total_alt_def resumption.rel_eq[symmetric] resumption.rel_conversep[symmetric] rel_resumption_OO resumption.rel_mono)

(* rel_resumption preserves right-uniqueness of its argument relations. *)
lemma right_unique_rel_resumption [transfer_rule]:
  " right_unique R1; right_unique R2   right_unique (rel_resumption R1 R2)"
  by(simp only: right_unique_alt_def resumption.rel_eq[symmetric] resumption.rel_conversep[symmetric] rel_resumption_OO resumption.rel_mono)

(* rel_resumption preserves bi-totality; combines the left- and right-total
   lemmas. *)
lemma bi_total_rel_resumption [transfer_rule]:
  " bi_total A; bi_total B   bi_total (rel_resumption A B)"
unfolding bi_total_alt_def
by(blast intro: left_total_rel_resumption right_total_rel_resumption)

(* rel_resumption preserves bi-uniqueness; combines the left- and
   right-unique lemmas. *)
lemma bi_unique_rel_resumption [transfer_rule]:
  " bi_unique A; bi_unique B   bi_unique (rel_resumption A B)"
unfolding bi_unique_alt_def
by(blast intro: left_unique_rel_resumption right_unique_rel_resumption)

(* Quotient map rule for resumption: quotients on both type arguments lift to
   a quotient on resumptions, with map_resumption as abstraction and
   representation functions. *)
lemma Quotient_resumption [quot_map]:
  " Quotient R1 Abs1 Rep1 T1; Quotient R2 Abs2 Rep2 T2 
   Quotient (rel_resumption R1 R2) (map_resumption Abs1 Abs2) (map_resumption Rep1 Rep2) (rel_resumption T1 T2)"
  by(simp add: Quotient_alt_def5 resumption.rel_Grp[of UNIV _ UNIV _, symmetric, simplified] resumption.rel_compp resumption.rel_conversep[symmetric] resumption.rel_mono)

end

Theory Generat

(* Title: Generat.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Generative probabilistic values›

theory Generat imports 
  Misc_CryptHOL
begin

subsection ‹Single-step generative›

(* One step of a generative value: either Pure with a result of type 'a, or
   IO with an output of type 'b and a continuation of type 'c.
   The set functions are named generat_pures, generat_outs, generat_conts and
   the selectors result, output, continuation. *)
datatype (generat_pures: 'a, generat_outs: 'b, generat_conts: 'c) generat 
  = Pure (result: 'a)
  | IO ("output": 'b) (continuation: "'c")

(* Register generat with the old-style datatype interface.
   NOTE(review): presumably required by tools that still rely on that
   interface -- confirm which one needs it. *)
datatype_compat generat

(* Congruence rule for IO that only mentions the output argument; the
   continuation c is the same on both sides.  The rule is added to the
   simpset used by Code_Simp. *)
lemma IO_code_cong: "out = out'  IO out c = IO out' c" by simp
setup Code_Simp.map_ss (Simplifier.add_cong @{thm IO_code_cong})

(* Mapping does not change whether a generat is Pure. *)
lemma is_Pure_map_generat [simp]: "is_Pure (map_generat f g h x) = is_Pure x"
by(cases x) simp_all

(* On a Pure value, mapping applies f to the result. *)
lemma result_map_generat [simp]: "is_Pure x  result (map_generat f g h x) = f (result x)"
by(cases x) simp_all

(* On a value that is not Pure, mapping applies g to the output. *)
lemma output_map_generat [simp]: "¬ is_Pure x  output (map_generat f g h x) = g (output x)"
by(cases x) simp_all

(* On a value that is not Pure, mapping applies h to the continuation. *)
lemma continuation_map_generat [simp]: "¬ is_Pure x  continuation (map_generat f g h x) = h (continuation x)"
by(cases x) simp_all

(* Simp rules for equations between a mapped generat and Pure x, in both
   orientations: the argument must be Pure x' with x = f x'. *)
lemma [simp]:
  shows map_generat_eq_Pure:
  "map_generat f g h generat = Pure x  (x'. generat = Pure x'  x = f x')"
  and Pure_eq_map_generat:
  "Pure x = map_generat f g h generat  (x'. generat = Pure x'  x = f x')"
by(cases generat; auto; fail)+

(* Simp rules for equations between a mapped generat and IO out c, in both
   orientations: the argument must be IO out' c' with out = g out' and
   c = h c'. *)
lemma [simp]:
  shows map_generat_eq_IO:
  "map_generat f g h generat = IO out c  (out' c'. generat = IO out' c'  out = g out'  c = h c')"
  and IO_eq_map_generat:
  "IO out c = map_generat f g h generat  (out' c'. generat = IO out' c'  out = g out'  c = h c')"
by(cases generat; auto; fail)+

(* Case rule for is_Pure: a Pure generat has the form Pure x. *)
lemma is_PureE [cases pred]:
  assumes "is_Pure generat"
  obtains (Pure) x where "generat = Pure x"
using assms by(auto simp add: is_Pure_def)

(* Elimination rule: a generat that is not Pure has the form IO out c. *)
lemma not_is_PureE:
  assumes "¬ is_Pure generat"
  obtains (IO) out c where "generat = IO out c"
using assms by(cases generat) auto

(* Selector-style introduction rule for rel_generat: x and y agree on
   is_Pure, their results are related by A when both are Pure, and their
   outputs and continuations are related by Out and R when neither is. *)
lemma rel_generatI:
  " is_Pure x  is_Pure y;
      is_Pure x; is_Pure y   A (result x) (result y);
      ¬ is_Pure x; ¬ is_Pure y   Out (output x) (output y)  R (continuation x) (continuation y) 
   rel_generat A Out R x y"
by(cases x y rule: generat.exhaust[case_product generat.exhaust]) simp_all

(* Selector-style destruction rule for rel_generat, giving all consequences
   in a single formula; the counterpart of rel_generatI. *)
lemma rel_generatD':
  "rel_generat A Out R x y
   (is_Pure x  is_Pure y)  
     (is_Pure x  is_Pure y  A (result x) (result y))  
     (¬ is_Pure x  ¬ is_Pure y  Out (output x) (output y)  R (continuation x) (continuation y))"
by(cases x y rule: generat.exhaust[case_product generat.exhaust]) simp_all

(* The consequences of rel_generatD' as four separately named destruction
   rules, for is_Pure, result, output and continuation. *)
lemma rel_generatD:
  assumes "rel_generat A Out R x y"
  shows rel_generat_is_PureD: "is_Pure x  is_Pure y"
  and rel_generat_resultD: "is_Pure x  is_Pure y  A (result x) (result y)"
  and rel_generat_outputD: "¬ is_Pure x  ¬ is_Pure y  Out (output x) (output y)"
  and rel_generat_continuationD: "¬ is_Pure x  ¬ is_Pure y  R (continuation x) (continuation y)"
using rel_generatD'[OF assms] by simp_all

(* Monotonicity of rel_generat in its three relation arguments, stated
   pointwise for a given pair x, y. *)
lemma rel_generat_mono:
  " rel_generat A B C x y; x y. A x y  A' x y; x y. B x y  B' x y; x y. C x y  C' x y 
   rel_generat A' B' C' x y"
using generat.rel_mono[of A A' B B' C C'] by(auto simp add: le_fun_def)

(* Monotonicity of rel_generat restated with the relation hypotheses first,
   declared with the [mono] attribute. Follows from rel_generat_mono. *)
lemma rel_generat_mono' [mono]:
  " x y. A x y  A' x y; x y. B x y  B' x y; x y. C x y  C' x y 
   rel_generat A B C x y  rel_generat A' B' C' x y"
by(blast intro: rel_generat_mono)

(* Characterises rel_generat A B C r r (a value related to itself) through
   A, B and C holding reflexively on the members of generat_pures r,
   generat_outs r and generat_conts r, respectively. *)
lemma rel_generat_same:
  "rel_generat A B C r r  
  (x  generat_pures r. A x x) 
  (out  generat_outs r. B out out) 
  (c generat_conts r. C c c)"
by(cases r)(auto simp add: rel_fun_def)

(* Introduction rule for rel_generat A B C x x: it suffices that A, B and C are
   reflexive on the elements of generat_pures x, generat_outs x and
   generat_conts x. *)
lemma rel_generat_reflI:
  " y. y  generat_pures x  A y y; 
     out. out  generat_outs x  B out out;
     cont. cont  generat_conts x  C cont cont 
   rel_generat A B C x x"
by(cases x) auto

(* Simp rule relating reflexivity of rel_generat A B C to reflexivity of the
   component relations A, B and C. The proof instantiates reflpD at Pure and
   IO values to extract the component facts. *)
lemma reflp_rel_generat [simp]: "reflp (rel_generat A B C)  reflp A  reflp B  reflp C"
by(auto 4 3 intro!: reflpI rel_generatI dest: reflpD reflpD[where x="Pure _"] reflpD[where x="IO _ _"])

(* rel_generat preserves transitivity: if A, B and C are transitive, then so is
   rel_generat A B C. *)
lemma transp_rel_generatI:
  assumes "transp A" "transp B" "transp C"
  shows "transp (rel_generat A B C)"
by(rule transpI)(auto 6 5 dest: rel_generatD' intro!: rel_generatI intro: assms[THEN transpD] simp add: rel_fun_def)

(* rel_generat commutes with binary infimum of relations. The proof is by
   antisymmetry: one inclusion by case analysis on the relator, the other by
   monotonicity. *)
lemma rel_generat_inf:
  "inf (rel_generat A B C) (rel_generat A' B' C') = rel_generat (inf A A') (inf B B') (inf C C')"
  (is "?lhs = ?rhs")
proof(rule antisym)
  show "?lhs  ?rhs"
    by(auto elim!: generat.rel_cases simp add: rel_fun_def)
qed(auto elim: rel_generat_mono)

(* rel_generat with Pure x fixed as the first argument, as a predicate on the
   second argument: it must be some Pure y with A x y. *)
lemma rel_generat_Pure1: "rel_generat A B C (Pure x) = (λr. y. r = Pure y  A x y)"
by(rule ext)(case_tac r, simp_all)

(* rel_generat with IO out c fixed as the first argument, as a predicate on the
   second argument: it must be some IO out' c' with B out out' and C c c'. *)
lemma rel_generat_IO1: "rel_generat A B C (IO out c) = (λr. out' c'. r = IO out' c'  B out out'  C c c')"
by(rule ext)(case_tac r, simp_all)

(* Rewrites ¬ is_Pure r into the statement that r has the form IO out c. *)
lemma not_is_Pure_conv: "¬ is_Pure r  (out c. r = IO out c)"
by(cases r) auto

(* The output set of any generat value is finite (by case analysis). *)
lemma finite_generat_outs [simp]: "finite (generat_outs generat)"
by(cases generat) auto

(* Countability of the output set, as a consequence of finiteness
   (countable_finite). *)
lemma countable_generat_outs [simp]: "countable (generat_outs generat)"
by(simp add: countable_finite)

(* Fuses a case distinction with a preceding map_generat: the mapping functions
   a, b and d are pushed into the branches of case_generat. *)
lemma case_map_generat:
  "case_generat pure io (map_generat a b d r) = 
   case_generat (pure  a) (λout. io (b out)  d) r"
by(cases r) simp_all

(* For a non-Pure value, the selector continuation yields a member of
   generat_conts. *)
lemma continuation_in_generat_conts:
  "¬ is_Pure r  continuation r  generat_conts r"
by(cases r) auto


(* Option-valued destructor for the IO constructor: returns None on Pure and
   Some (out, c) on IO out c. *)
fun dest_IO :: "('a, 'out, 'c) generat  ('out × 'c) option"
where
  "dest_IO (Pure _) = None"
| "dest_IO (IO out c) = Some (out, c)"

(* Simp rule: dest_IO returns Some (out, c) precisely for the value IO out c. *)
lemma dest_IO_eq_Some_iff [simp]: "dest_IO generat = Some (out, c)  generat = IO out c"
by(cases generat) simp_all

(* Simp rule: dest_IO returns None precisely for Pure values. *)
lemma dest_IO_eq_None_iff [simp]: "dest_IO generat = None  is_Pure generat"
by(cases generat) simp_all

(* Simp rule: dest_IO applied after Pure is the constant-None function. *)
lemma dest_IO_comp_Pure [simp]: "dest_IO  Pure = (λ_. None)"
by(simp add: fun_eq_iff)

(* The domain of dest_IO (as a partial map) is the set of non-Pure values. *)
lemma dom_dest_IO: "dom dest_IO = {x. ¬ is_Pure x}"
by(auto simp add: not_is_Pure_conv)


(* Combines a set A of generat values into a single one, given one combining
   function per component. Depending on a test for Pure elements in A, the
   result is either Pure of lub1 applied to the results of the Pure elements of
   A, or IO built from lub2 / lub3 applied to the outputs / continuations of the
   non-Pure elements of A. *)
definition generat_lub :: "('a set  'b)  ('out set  'out')  ('cont set  'cont') 
   ('a, 'out, 'cont) generat set  ('b, 'out', 'cont') generat"
where
  "generat_lub lub1 lub2 lub3 A =
  (if xA. is_Pure x then Pure (lub1 (result ` (A  {f. is_Pure f})))
   else IO (lub2 (output ` (A  {f. ¬ is_Pure f}))) (lub3 (continuation ` (A  {f. ¬ is_Pure f}))))"

(* Simp rule: generat_lub yields a Pure value exactly when the condition of its
   defining if-expression holds. *)
lemma is_Pure_generat_lub [simp]:
  "is_Pure (generat_lub lub1 lub2 lub3 A)  (xA. is_Pure x)"
by(simp add: generat_lub_def)

(* Simp rule for the result selector on generat_lub under the stated premise on
   A: it is lub1 applied to the results of the Pure elements of A. *)
lemma result_generat_lub [simp]:
  "xA. is_Pure x  result (generat_lub lub1 lub2 lub3 A) = lub1 (result ` (A  {f. is_Pure f}))"
by(simp add: generat_lub_def)

(* The output selector on generat_lub under the stated premise on A: it is lub2
   applied to the outputs of the non-Pure elements of A. Not a simp rule. *)
lemma output_generat_lub: 
  "xA. ¬ is_Pure x  output (generat_lub lub1 lub2 lub3 A) = lub2 (output ` (A  {f. ¬ is_Pure f}))"
by(simp add: generat_lub_def)

(* The continuation selector on generat_lub under the stated premise on A: it is
   lub3 applied to the continuations of the non-Pure elements of A.
   Not a simp rule. *)
lemma continuation_generat_lub:
  "xA. ¬ is_Pure x  continuation (generat_lub lub1 lub2 lub3 A) = lub3 (continuation ` (A  {f. ¬ is_Pure f}))"
by(simp add: generat_lub_def)

(* Simp rule: generat_lub over the image of A under map_generat f g h equals
   generat_lub over A where each lub function is precomposed with the image
   under the corresponding map. *)
lemma generat_lub_map [simp]:
  "generat_lub lub1 lub2 lub3 (map_generat f g h ` A) = generat_lub (lub1  (`) f) (lub2  (`) g) (lub3  (`) h) A"
by(auto 4 3 simp add: generat_lub_def intro: arg_cong[where f=lub1] arg_cong[where f=lub2] arg_cong[where f=lub3] rev_image_eqI del: ext intro!: ext)

(* Simp rule: mapping over the result of generat_lub is the same as
   postcomposing each lub function with the corresponding map. *)
lemma map_generat_lub [simp]:
  "map_generat f g h (generat_lub lub1 lub2 lub3 A) = generat_lub (f  lub1) (g  lub2) (h  lub3) A"
by(simp add: generat_lub_def o_def)


(* Abbreviation for generat_lub where the result and output components are
   chosen by a definite description (THE) over the given set; only the
   function for combining continuations remains a parameter. *)
abbreviation generat_lub' :: "('cont set  'cont')  ('a, 'out, 'cont) generat set  ('a, 'out, 'cont') generat"
where "generat_lub'  generat_lub (λA. THE x. x  A) (λA. THE x. x  A)"

(* Zips two generat values of the same shape into one generat of pairs.
   Only the Pure/Pure and IO/IO cases are given equations; mixed-constructor
   arguments have no defining equation. *)
fun rel_witness_generat :: "('a, 'c, 'e) generat × ('b, 'd, 'f) generat  ('a × 'b, 'c × 'd, 'e × 'f) generat" where
  "rel_witness_generat (Pure x, Pure y) = Pure (x, y)"
| "rel_witness_generat (IO out c, IO out' c') = IO (out, out') (c, c')"

(* Witness properties of rel_witness_generat for related x and y: the pairs in
   its three set components lie in the graphs of A, C and R, and its first and
   second projections (via map_generat) give back x and y. *)
lemma rel_witness_generat: 
  assumes "rel_generat A C R x y"
  shows pures_rel_witness_generat: "generat_pures (rel_witness_generat (x, y))  {(a, b). A a b}"
    and outs_rel_witness_generat: "generat_outs (rel_witness_generat (x, y))  {(c, d). C c d}"
    and conts_rel_witness_generat: "generat_conts (rel_witness_generat (x, y))  {(e, f). R e f}"
    and map1_rel_witness_generat: "map_generat fst fst fst (rel_witness_generat (x, y)) = x"
    and map2_rel_witness_generat: "map_generat snd snd snd (rel_witness_generat (x, y)) = y"
  using assms by(cases; simp; fail)+

(* Bundles the three set-component facts about rel_witness_generat under one name. *)
lemmas set_rel_witness_generat = pures_rel_witness_generat outs_rel_witness_generat conts_rel_witness_generat

(* Relates x to the witness rel_witness_generat (x, y): each component of x
   equals the first component of the corresponding pair, and that pair
   satisfies the original relation. *)
lemma rel_witness_generat1:
  assumes "rel_generat A C R x y"
  shows "rel_generat (λa (a', b). a = a'  A a' b) (λc (c', d). c = c'  C c' d) (λr (r', s). r = r'  R r' s) x (rel_witness_generat (x, y))"
  using map1_rel_witness_generat[OF assms, symmetric]
  unfolding generat.rel_eq[symmetric] generat.rel_map
  by(rule generat.rel_mono_strong)(auto dest: set_rel_witness_generat[OF assms, THEN subsetD])

(* Relates the witness rel_witness_generat (x, y) to y: each component of y
   equals the second component of the corresponding pair, and that pair
   satisfies the original relation. *)
lemma rel_witness_generat2:
  assumes "rel_generat A C R x y"
  shows "rel_generat (λ(a, b') b. b = b'  A a b') (λ(c, d') d. d = d'  C c d') (λ(r, s') s. s = s'  R r s') (rel_witness_generat (x, y)) y"
  using map2_rel_witness_generat[OF assms]
  unfolding generat.rel_eq[symmetric] generat.rel_map
  by(rule generat.rel_mono_strong)(auto dest: set_rel_witness_generat[OF assms, THEN subsetD])

end

Theory Generative_Probabilistic_Value

(* Title: Generative_Probabilistic_Value.thy
  Author: Andreas Lochbihler, ETH Zurich *)

theory Generative_Probabilistic_Value imports
  Resumption
  Generat
  "HOL-Types_To_Sets.Types_To_Sets"
begin

(* Hide the short name of an existing constant Done; this theory defines its
   own Done further below. *)
hide_const (open) Done

subsection ‹Type definition›

(* Codatatype of generative probabilistic values. A gpv wraps (constructor GPV,
   selector the_gpv) an spmf over generat values whose continuation component
   maps an input of type 'in to a further gpv. The set functions for 'a and
   'out are named results'_gpv and outs'_gpv; 'in gets none.
   The option bnf_internals is switched on locally for this declaration. *)
context notes [[bnf_internals]] begin

codatatype (results'_gpv: 'a, outs'_gpv: 'out, 'in) gpv
  = GPV (the_gpv: "('a, 'out, 'in  ('a, 'out, 'in) gpv) generat spmf")

end

(* Register the relator-equality fact of gpv with the relator_eq attribute. *)
declare gpv.rel_eq [relator_eq]

text ‹Reactive values are like generative values, except that they take an input first.›

(* A reactive value is a function from an input to a gpv. *)
type_synonym ('a, 'out, 'in) rpv = "'in  ('a, 'out, 'in) gpv"
(* Print translation on the function type constructor: a function type whose
   range is a gpv is displayed as rpv when the argument type in1 coincides with
   the gpv's input type in2; otherwise Match is raised so that the default
   printing applies. The typ command afterwards displays the rpv type. *)
print_translation ― ‹pretty printing for @{typ "('a, 'out, 'in) rpv"} let
    fun tr' [in1, Const (@{type_syntax gpv}, _) $ a $ out $ in2] =
      if in1 = in2 then Syntax.const @{type_syntax rpv} $ a $ out $ in1
      else raise Match;
  in [(@{type_syntax "fun"}, K tr')]
  end
typ "('a, 'out, 'in) rpv"
text ‹
  Effectively, @{typ "('a, 'out, 'in) gpv"} and @{typ "('a, 'out, 'in) rpv"} are mutually recursive.
›

(* Rewrites an equation f = GPV g into an equation on the selector the_gpv. *)
lemma eq_GPV_iff: "f = GPV g  the_gpv f = g"
by(cases f) auto

(* Simp setup: remove the generated set equations gpv.set from the simpset and
   add gpv.set_map to it. *)
declare gpv.set[simp del]

declare gpv.set_map[simp]

(* Alternative characterisation of rel_gpv through a witness gpv'' of pairs:
   its results and outputs satisfy A and B, and its first / second projections
   under map_gpv are gpv and gpv'. *)
lemma rel_gpv_def':
  "rel_gpv A B gpv gpv' 
  (gpv''. ((x, y)  results'_gpv gpv''. A x y)  ((x, y)  outs'_gpv gpv''. B x y) 
           map_gpv fst fst gpv'' = gpv  map_gpv snd snd gpv'' = gpv')"
unfolding rel_gpv_def by(auto simp add: BNF_Def.Grp_def)

(* Results of a reactive value, defined from results'_gpv over range rpv
   (see in_results'_rpv for the membership characterisation). *)
definition results'_rpv :: "('a, 'out, 'in) rpv  'a set"
where "results'_rpv rpv = range rpv  results'_gpv"

(* Outputs of a reactive value, defined from outs'_gpv over range rpv
   (see in_outs_rpv for the membership characterisation). *)
definition outs'_rpv :: "('a, 'out, 'in) rpv  'out set"
where "outs'_rpv rpv = range rpv  outs'_gpv"

(* Relator for reactive values: the function relator with equality on inputs
   and rel_gpv A B on the resulting gpvs. *)
abbreviation rel_rpv
  :: "('a  'b  bool)  ('out  'out'  bool)
   ('in  ('a, 'out, 'in) gpv)  ('in  ('b, 'out', 'in) gpv)  bool"
where "rel_rpv A B  rel_fun (=) (rel_gpv A B)"

(* Membership in results'_rpv rpv expressed via results'_gpv (rpv input) for
   an input; declared [iff]. *)
lemma in_results'_rpv [iff]: "x  results'_rpv rpv  (input. x  results'_gpv (rpv input))"
by(simp add: results'_rpv_def)

(* Membership in outs'_rpv rpv expressed via outs'_gpv (rpv input) for an
   input; declared [iff]. *)
lemma in_outs_rpv [iff]: "out  outs'_rpv rpv  (input. out  outs'_gpv (rpv input))"
by(simp add: outs'_rpv_def)

(* Simp rule unfolding results'_gpv at the constructor GPV: the results come
   from generat_pures of the generats in set_spmf r, together with
   results'_rpv of their continuations (generat_conts). *)
lemma results'_GPV [simp]:
  "results'_gpv (GPV r) =
   (set_spmf r  generat_pures)  
   ((set_spmf r  generat_conts)  results'_rpv)"
by(auto simp add: gpv.set bind_UNION set_spmf_def)

(* Simp rule unfolding outs'_gpv at the constructor GPV: the outputs come from
   generat_outs of the generats in set_spmf r, together with outs'_rpv of
   their continuations (generat_conts). *)
lemma outs'_GPV [simp]:
  "outs'_gpv (GPV r) =
   (set_spmf r  generat_outs)  
   ((set_spmf r  generat_conts)  outs'_rpv)"
by(auto simp add: gpv.set bind_UNION set_spmf_def)

(* The unfolding equation of outs'_GPV stated for an arbitrary gpv r via its
   selector the_gpv. Not a simp rule. *)
lemma outs'_gpv_unfold:
  "outs'_gpv r =
   (set_spmf (the_gpv r)  generat_outs)  
   ((set_spmf (the_gpv r)  generat_conts)  outs'_rpv)"
by(cases r) simp

(* Induction rule for membership in outs'_gpv, phrased with set_spmf and the_gpv.
   Case Out: x is an output of a generat in the support of the gpv.
   Case Cont: x is an output of c input for a continuation c of such a generat,
   with the induction hypothesis P (c input). *)
lemma outs'_gpv_induct [consumes 1, case_names Out Cont, induct set: outs'_gpv]:
  assumes x: "x  outs'_gpv gpv"
  and Out: "generat gpv.  generat  set_spmf (the_gpv gpv); x  generat_outs generat   P gpv"
  and Cont: "generat gpv c input.
     generat  set_spmf (the_gpv gpv); c  generat_conts generat; x  outs'_gpv (c input); P (c input)   P gpv"
  shows "P gpv"
using x
apply(induction y"x" gpv)
 apply(rule Out, simp add: in_set_spmf, simp)
apply(erule imageE, rule Cont, simp add: in_set_spmf, simp, simp, simp)
.

(* Case-analysis rule for membership in outs'_gpv with the same two cases as
   outs'_gpv_induct, but without an induction hypothesis. *)
lemma outs'_gpv_cases [consumes 1, case_names Out Cont, cases set: outs'_gpv]:
  assumes "x  outs'_gpv gpv"
  obtains (Out) generat where "generat  set_spmf (the_gpv gpv)" "x  generat_outs generat"
    | (Cont) generat c input where "generat  set_spmf (the_gpv gpv)" "c  generat_conts generat" "x  outs'_gpv (c input)"
using assms by cases(auto simp add: in_set_spmf)

(* Introduction rules for membership in outs'_gpv: directly from an output of a
   generat in the support (outs'_gpv_Out), or from an output reachable through
   one of its continuations (outs'_gpv_Cont). *)
lemma outs'_gpvI [intro?]:
  shows outs'_gpv_Out: " generat  set_spmf (the_gpv gpv); x  generat_outs generat   x  outs'_gpv gpv"
  and outs'_gpv_Cont: " generat  set_spmf (the_gpv gpv); c  generat_conts generat; x  outs'_gpv (c input)   x  outs'_gpv gpv"
by(auto intro: gpv.set_sel simp add: in_set_spmf)

(* Induction rule for membership in results'_gpv, phrased with set_spmf and
   the_gpv. Case Pure: x is a pure result of a generat in the support of the gpv.
   Case Cont: x is a result of c input for a continuation c of such a generat,
   with the induction hypothesis P (c input). *)
lemma results'_gpv_induct [consumes 1, case_names Pure Cont, induct set: results'_gpv]:
  assumes x: "x  results'_gpv gpv"
  and Pure: "generat gpv.  generat  set_spmf (the_gpv gpv); x  generat_pures generat   P gpv"
  and Cont: "generat gpv c input.
     generat  set_spmf (the_gpv gpv); c  generat_conts generat; x  results'_gpv (c input); P (c input)   P gpv"
  shows "P gpv"
using x
apply(induction y"x" gpv)
 apply(rule Pure; simp add: in_set_spmf)
apply(erule imageE, rule Cont, simp add: in_set_spmf, simp, simp, simp)
.

(* Case-analysis rule for membership in results'_gpv with the same two cases as
   results'_gpv_induct, but without an induction hypothesis. *)
lemma results'_gpv_cases [consumes 1, case_names Pure Cont, cases set: results'_gpv]:
  assumes "x  results'_gpv gpv"
  obtains (Pure) generat where "generat  set_spmf (the_gpv gpv)" "x  generat_pures generat"
    | (Cont) generat c input where "generat  set_spmf (the_gpv gpv)" "c  generat_conts generat" "x  results'_gpv (c input)"
using assms by cases(auto simp add: in_set_spmf)

(* Introduction rules for membership in results'_gpv: directly from a pure
   result of a generat in the support (results'_gpv_Pure), or from a result
   reachable through one of its continuations (results'_gpv_Cont). *)
lemma results'_gpvI [intro?]:
  shows results'_gpv_Pure: " generat  set_spmf (the_gpv gpv); x  generat_pures generat   x  results'_gpv gpv"
  and results'_gpv_Cont: " generat  set_spmf (the_gpv gpv); c  generat_conts generat; x  results'_gpv (c input)   x  results'_gpv gpv"
by(auto intro: gpv.set_sel simp add: in_set_spmf)

(* rel_gpv preserves left-uniqueness. The proof rewrites the goal with the
   alternative definition and the relator laws for converse and composition,
   then finishes by monotonicity. *)
lemma left_unique_rel_gpv [transfer_rule]:
  " left_unique A; left_unique B   left_unique (rel_gpv A B)"
unfolding left_unique_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

(* rel_gpv preserves right-uniqueness; same proof pattern as
   left_unique_rel_gpv. *)
lemma right_unique_rel_gpv [transfer_rule]:
  " right_unique A; right_unique B   right_unique (rel_gpv A B)"
unfolding right_unique_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

(* rel_gpv preserves bi-uniqueness, by combining the left- and right-unique
   results. *)
lemma bi_unique_rel_gpv [transfer_rule]:
  " bi_unique A; bi_unique B   bi_unique (rel_gpv A B)"
unfolding bi_unique_alt_def by(simp add: left_unique_rel_gpv right_unique_rel_gpv)

(* rel_gpv preserves left-totality; proved via the alternative definition and
   the relator laws for converse and composition, then monotonicity. *)
lemma left_total_rel_gpv [transfer_rule]:
  " left_total A; left_total B   left_total (rel_gpv A B)"
unfolding left_total_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

(* rel_gpv preserves right-totality; same proof pattern as left_total_rel_gpv. *)
lemma right_total_rel_gpv [transfer_rule]:
  " right_total A; right_total B   right_total (rel_gpv A B)"
unfolding right_total_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

(* rel_gpv preserves bi-totality, by combining the left- and right-total
   results. *)
lemma bi_total_rel_gpv [transfer_rule]: " bi_total A; bi_total B   bi_total (rel_gpv A B)"
unfolding bi_total_alt_def by(simp add: left_total_rel_gpv right_total_rel_gpv)

(* Register the generated transfer fact for map_gpv as a transfer rule. *)
declare gpv.map_transfer[transfer_rule]

(* map_gpv distributes over if-then-else; collected in the named theorem set
   if_distribs. *)
lemma if_distrib_map_gpv [if_distribs]:
  "map_gpv f g (if b then gpv else gpv') = (if b then map_gpv f g gpv else map_gpv f g gpv')"
by simp

(* Strong monotonicity of pred_gpv: the predicates may be weakened using the
   extra knowledge that the argument is a result / output of x. *)
lemma gpv_pred_mono_strong:
  " pred_gpv P Q x; a.  a  results'_gpv x; P a   P' a; b.  b  outs'_gpv x; Q b   Q' b   pred_gpv P' Q' x"
by(simp add: pred_gpv_def)

(* Simp rule: pred_gpv with two constantly-True predicates is constantly True. *)
lemma pred_gpv_top [simp]:
  "pred_gpv (λ_. True) (λ_. True) = (λ_. True)"
by(simp add: pred_gpv_def)

(* Simp rules splitting pred_gpv when its first (pred_gpv_conj1) or second
   (pred_gpv_conj2) predicate argument is a pointwise combination of two
   predicates. *)
lemma pred_gpv_conj [simp]:
  shows pred_gpv_conj1: "P Q R. pred_gpv (λx. P x  Q x) R = (λx. pred_gpv P R x  pred_gpv Q R x)"
  and pred_gpv_conj2: "P Q R. pred_gpv P (λx. Q x  R x) = (λx. pred_gpv P Q x  pred_gpv P R x)"
by(auto simp add: pred_gpv_def)

(* Introduction rule: a rel_gpv fact together with pred_gpv facts on both sides
   yields rel_gpv for the relations restricted by those predicates. *)
lemma rel_gpv_restrict_relp1I [intro?]:
  " rel_gpv R R' x y; pred_gpv P P' x; pred_gpv Q Q' y   rel_gpv (R  P  Q) (R'  P'  Q') x y"
by(erule gpv.rel_mono_strong)(simp_all add: pred_gpv_def)

(* Elimination rule, converse to rel_gpv_restrict_relp1I: from rel_gpv over the
   restricted relations obtain the unrestricted rel_gpv fact and the pred_gpv
   facts for x and y. The predicate facts are extracted through Domainp of the
   restricted relations (for y, of their converses). *)
lemma rel_gpv_restrict_relpE [elim?]:
  assumes "rel_gpv (R  P  Q) (R'  P'  Q') x y"
  obtains "rel_gpv R R' x y" "pred_gpv P P' x" "pred_gpv Q Q' y"
proof
  show "rel_gpv R R' x y" using assms by(auto elim!: gpv.rel_mono_strong)
  have "pred_gpv (Domainp (R  P  Q)) (Domainp (R'  P'  Q')) x" using assms by(fold gpv.Domainp_rel) blast
  then show "pred_gpv P P' x" by(rule gpv_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)+
  have "pred_gpv (Domainp (R  P  Q)¯¯) (Domainp (R'  P'  Q')¯¯) y" using assms
    by(fold gpv.Domainp_rel)(auto simp only: gpv.rel_conversep Domainp_conversep)
  then show "pred_gpv Q Q' y" by(rule gpv_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

(* Simp rule: pred_gpv on a mapped gpv equals pred_gpv on the original gpv with
   the predicates composed with the mapping functions. *)
lemma gpv_pred_map [simp]: "pred_gpv P Q (map_gpv f g gpv) = pred_gpv (P  f) (Q  g) gpv"
by(simp add: pred_gpv_def)

subsection ‹Generalised mapper and relator›

context includes lifting_syntax begin

(* Generalised mapper for gpv that can also change the input type: f maps
   results, g maps outputs, and h translates inputs of the new type back to the
   old one (note its reversed direction in the type). Defined corecursively in
   two map_spmf passes: first precompose continuations with h (map_fun h id),
   then map results and outputs and recurse into the continuations.
   The generated selector equation is removed from the simpset; map_gpv'_sel
   below provides the fused form instead. *)
primcorec map_gpv' :: "('a  'b)  ('out  'out')  ('ret'  'ret)  ('a, 'out, 'ret) gpv  ('b, 'out', 'ret') gpv"
where
  "map_gpv' f g h gpv = 
   GPV (map_spmf (map_generat f g ((∘) (map_gpv' f g h))) (map_spmf (map_generat id id (map_fun h id)) (the_gpv gpv)))"

declare map_gpv'.sel [simp del]

(* Simp rule: selector equation for map_gpv' with the two map_spmf passes of
   the definition fused into one, using the function-map notation h ---> _. *)
lemma map_gpv'_sel [simp]:
  "the_gpv (map_gpv' f g h gpv) = map_spmf (map_generat f g (h ---> map_gpv' f g h)) (the_gpv gpv)"
by(simp add: map_gpv'.sel spmf.map_comp o_def generat.map_comp map_fun_def[abs_def])

(* Simp rule: constructor form of map_gpv'_sel, for an argument GPV p. *)
lemma map_gpv'_GPV [simp]:
  "map_gpv' f g h (GPV p) = GPV (map_spmf (map_generat f g (h ---> map_gpv' f g h)) p)"
by(rule gpv.expand) simp

(* Functor identity law for map_gpv', proved by coinduction. *)
lemma map_gpv'_id: "map_gpv' id id id = id"
apply(rule ext)
apply(coinduction)
apply(auto simp add: spmf_rel_map generat.rel_map rel_fun_def intro!: rel_spmf_reflI generat.rel_refl)
done

(* Functor composition law for map_gpv'. The input translations compose in the
   opposite order (h' before h) to the result and output maps. *)
lemma map_gpv'_comp: "map_gpv' f g h (map_gpv' f' g' h' gpv) = map_gpv' (f  f') (g  g') (h'  h) gpv"
by(coinduction arbitrary: gpv)(auto simp add: spmf.map_comp spmf_rel_map generat.rel_map rel_fun_def intro!: rel_spmf_reflI generat.rel_refl)

(* Register map_gpv' with the functor command, discharging the proof
   obligations with map_gpv'_comp and map_gpv'_id. *)
functor gpv: map_gpv' by(simp_all add: map_gpv'_comp map_gpv'_id o_def) 

(* The BNF mapper map_gpv is the special case of map_gpv' whose input
   translation is id. *)
lemma map_gpv_conv_map_gpv': "map_gpv f g = map_gpv' f g id"
apply(rule ext)
apply(coinduction)
apply(auto simp add: gpv.map_sel spmf_rel_map generat.rel_map rel_fun_def intro!: generat.rel_refl_strong rel_spmf_reflI)
done

(* Generalised relator for gpv with a third relation R on inputs. Two gpvs are
   related if their underlying spmfs are related by rel_spmf over rel_generat,
   where continuations are related by R ===> rel_gpv'' A C R, i.e. R-related
   inputs lead to related gpvs again. Defined coinductively. *)
coinductive rel_gpv'' :: "('a  'b  bool)  ('out  'out'  bool)  ('ret  'ret'  bool)  ('a, 'out, 'ret) gpv  ('b, 'out', 'ret') gpv  bool"
  for A C R
where
  "rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) (the_gpv gpv) (the_gpv gpv')
   rel_gpv'' A C R gpv gpv'"

(* Coinduction rule for rel_gpv'' in a form suited to the coinduction method
   (registered as the coinduct rule for the predicate): in the step, related
   continuations may land either in the candidate relation X or in rel_gpv''
   itself. *)
lemma rel_gpv''_coinduct [consumes 1, case_names rel_gpv'', coinduct pred: rel_gpv'']:
  "X gpv gpv';
    gpv gpv'. X gpv gpv'
      rel_spmf (rel_generat A C (R ===> (λgpv gpv'. X gpv gpv'  rel_gpv'' A C R gpv gpv')))
           (the_gpv gpv) (the_gpv gpv') 
    rel_gpv'' A C R gpv gpv'"
by(erule rel_gpv''.coinduct) blast

(* Destruction rule: unfold rel_gpv'' once to the relation on the underlying
   spmfs. *)
lemma rel_gpv''D:
  "rel_gpv'' A C R gpv gpv' 
   rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) (the_gpv gpv) (the_gpv gpv')"
by(simp add: rel_gpv''.simps)

(* Simp rule: rel_gpv'' between two GPV constructor terms unfolds to the
   relation on their spmf arguments. *)
lemma rel_gpv''_GPV [simp]:
  "rel_gpv'' A C R (GPV p) (GPV q) 
   rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) p q"
by(simp add: rel_gpv''.simps)

(* The BNF relator rel_gpv is the instance of rel_gpv'' with equality on
   inputs. Each direction is shown by coinduction. *)
lemma rel_gpv_conv_rel_gpv'': "rel_gpv A C = rel_gpv'' A C (=)"
proof(rule ext iffI)+
  show "rel_gpv A C gpv gpv'" if "rel_gpv'' A C (=) gpv gpv'" for gpv :: "('a, 'b, 'c) gpv" and gpv' :: "('d, 'e, 'c) gpv"
    using that by(coinduct)(blast dest: rel_gpv''D)
  show "rel_gpv'' A C (=) gpv gpv'" if "rel_gpv A C gpv gpv'" for gpv :: "('a, 'b, 'c) gpv" and gpv' :: "('d, 'e, 'c) gpv"
    using that by(coinduct)(auto elim!: gpv.rel_cases rel_spmf_mono generat.rel_mono_strong rel_fun_mono)
qed

(* rel_gpv'' over three equalities is equality on gpvs. *)
lemma rel_gpv''_eq (* [relator_eq] do not use this attribute unless all transfer rules for gpv have been changed to rel_gpv'' *):
  "rel_gpv'' (=) (=) (=) = (=)"
by(simp add: rel_gpv_conv_rel_gpv''[symmetric] gpv.rel_eq)

(* Monotonicity of rel_gpv'' with respect to the relation order: the
   assumptions order A against A' and C against C' in one direction, and the
   input relations in the opposite direction (R' against R). *)
lemma rel_gpv''_mono:
  assumes "A  A'" "C  C'" "R'  R"
  shows "rel_gpv'' A C R  rel_gpv'' A' C' R'"
proof
  show "rel_gpv'' A' C' R' gpv gpv'" if "rel_gpv'' A C R gpv gpv'" for gpv gpv' using that
    by(coinduct)(auto dest: rel_gpv''D elim!: rel_spmf_mono generat.rel_mono_strong rel_fun_mono intro: assms[THEN predicate2D])
qed

(* rel_gpv'' commutes with relation converse. One direction is proved by
   coinduction for arbitrary A, C, R; the other is obtained by instantiating
   that result with the converse relations. *)
lemma rel_gpv''_conversep: "rel_gpv'' A¯¯ C¯¯ R¯¯ = (rel_gpv'' A C R)¯¯"
proof(intro ext iffI; simp)
  show "rel_gpv'' A C R gpv gpv'" if "rel_gpv'' A¯¯ C¯¯ R¯¯ gpv' gpv"
    for A :: "'a1  'a2  bool" and C :: "'c1  'c2  bool" and R :: "'r1  'r2  bool" and gpv gpv'
    using that apply(coinduct)
    apply(drule rel_gpv''D)
    apply(rewrite in  conversep_iff[symmetric])
    apply(subst spmf_rel_conversep[symmetric])
    apply(erule rel_spmf_mono)
    apply(subst generat.rel_conversep[symmetric])
    apply(erule generat.rel_mono_strong)
    apply(auto simp add: rel_fun_def conversep_iff[abs_def])
    done
  from this[of "A¯¯" "C¯¯" "R¯¯"]
  show "rel_gpv'' A¯¯ C¯¯ R¯¯ gpv' gpv" if "rel_gpv'' A C R gpv gpv'" for gpv gpv' using that by simp
qed


(* Positive distributivity: the composition (OO) of two rel_gpv'' relations is
   compared with rel_gpv'' of the componentwise compositions. Proved by
   coinduction using the corresponding laws for rel_spmf, rel_generat and
   rel_fun. *)
lemma rel_gpv''_pos_distr:
  "rel_gpv'' A C R OO rel_gpv'' A' C' R'  rel_gpv'' (A OO A') (C OO C') (R OO R')"
proof(rule predicate2I; erule relcomppE)
  show "rel_gpv'' (A OO A') (C OO C') (R OO R') gpv gpv''"
    if "rel_gpv'' A C R gpv gpv'" "rel_gpv'' A' C' R' gpv' gpv''"
    for gpv gpv' gpv'' using that
    apply(coinduction arbitrary: gpv gpv' gpv'')
    apply(drule rel_gpv''D)+
    apply(drule (1) rel_spmf_pos_distr[THEN predicate2D, OF relcomppI])
    apply(erule spmf_rel_mono_strong)
    apply(subst (asm) generat.rel_compp[symmetric])
    apply(erule generat.rel_mono_strong, assumption, assumption)
    apply(drule pos_fun_distr[THEN predicate2D])
    apply(auto simp add: rel_fun_def)
    done
qed

(* rel_gpv'' is left-unique when A and C are left-unique and the input relation
   R is left-total. Proved from positive distributivity and monotonicity. *)
lemma left_unique_rel_gpv'':
  " left_unique A; left_unique C; left_total R   left_unique (rel_gpv'' A C R)"
unfolding left_unique_alt_def left_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[OF rel_gpv''_pos_distr])
apply(erule (2) rel_gpv''_mono)
done

(* rel_gpv'' is right-unique when A and C are right-unique and the input
   relation R is right-total; same proof pattern as left_unique_rel_gpv''. *)
lemma right_unique_rel_gpv'':
  " right_unique A; right_unique C; right_total R   right_unique (rel_gpv'' A C R)"
unfolding right_unique_alt_def right_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[OF rel_gpv''_pos_distr])
apply(erule (2) rel_gpv''_mono)
done

(* rel_gpv'' is bi-unique when A and C are bi-unique and R is bi-total;
   combines the left- and right-unique lemmas. *)
lemma bi_unique_rel_gpv'' [transfer_rule]:
  " bi_unique A; bi_unique C; bi_total R   bi_unique (rel_gpv'' A C R)"
unfolding bi_unique_alt_def bi_total_alt_def by(blast intro: left_unique_rel_gpv'' right_unique_rel_gpv'')

(* A map_gpv on the left argument of rel_gpv'' can be moved into the relations
   A and C; the input relation R is unaffected. Both directions are proved by
   the same coinductive script. *)
lemma rel_gpv''_map_gpv1:
  "rel_gpv'' A C R (map_gpv f g gpv) gpv' = rel_gpv'' (λa. A (f a)) (λc. C (g c)) R gpv gpv'" (is "?lhs = ?rhs")
proof
  show ?rhs if ?lhs using that
    apply(coinduction arbitrary: gpv gpv')
    apply(drule rel_gpv''D)
    apply(simp add: gpv.map_sel spmf_rel_map)
    apply(erule rel_spmf_mono)
    by(auto simp add: generat.rel_map rel_fun_comp elim!: generat.rel_mono_strong rel_fun_mono)
  show ?lhs if ?rhs using that
    apply(coinduction arbitrary: gpv gpv')
    apply(drule rel_gpv''D)
    apply(simp add: gpv.map_sel spmf_rel_map)
    apply(erule rel_spmf_mono)
    by(auto simp add: generat.rel_map rel_fun_comp elim!: generat.rel_mono_strong rel_fun_mono)
qed

(* A map_gpv on the right argument of rel_gpv'' can be moved into the relations
   A and C. Derived from rel_gpv''_map_gpv1 instantiated with the converse
   relations, then rewriting with rel_gpv''_conversep and conversep_iff. *)
lemma rel_gpv''_map_gpv2:
  "rel_gpv'' A C R gpv (map_gpv f g gpv') = rel_gpv'' (λa b. A a (f b)) (λc d. C c (g d)) R gpv gpv'"
  using rel_gpv''_map_gpv1[of "conversep A" "conversep C" "conversep R" f g gpv' gpv]
  apply(rewrite in " = _" conversep_iff[symmetric])
  apply(rewrite in "_ = " conversep_iff[symmetric])
  apply(simp only: rel_gpv''_conversep)
  apply(simp only: rel_gpv''_conversep[symmetric])
  apply(simp only: conversep_iff[abs_def])
  done

(* Both map_gpv rules under one name; the left-hand rule is abstracted over its
   last argument via abs_def. *)
lemmas rel_gpv''_map_gpv = rel_gpv''_map_gpv1[abs_def] rel_gpv''_map_gpv2

(* Simp rules moving the result and output maps f and g of a map_gpv' (on
   either side of rel_gpv'') into the relations, leaving map_gpv' id id h
   behind. The NO_MATCH id premises restrict the rules to f / g that do not
   syntactically match id. *)
lemma rel_gpv''_map_gpv' [simp]:
  shows "f g h gpv. NO_MATCH id f  NO_MATCH id g 
     rel_gpv'' A C R (map_gpv' f g h gpv) = rel_gpv'' (λa. A (f a)) (λc. C (g c)) R (map_gpv' id id h gpv)"
    and "f g h gpv gpv'. NO_MATCH id f  NO_MATCH id g 
     rel_gpv'' A C R gpv (map_gpv' f g h gpv') = rel_gpv'' (λa b. A a (f b)) (λc d. C c (g d)) R gpv (map_gpv' id id h gpv')"
proof (goal_cases)
  case (1 f g h gpv)
  then show ?case using map_gpv'_comp[of f g id id id h gpv, symmetric] by(simp add: rel_gpv''_map_gpv[unfolded map_gpv_conv_map_gpv'])
next
  case (2 f g h gpv gpv')
  then show ?case using map_gpv'_comp[of f g id id id h gpv', symmetric] by(simp add: rel_gpv''_map_gpv[unfolded map_gpv_conv_map_gpv'])
qed

(* The rules of rel_gpv''_map_gpv' specialised to equality on inputs and folded
   back to rel_gpv. *)
lemmas rel_gpv_map_gpv' = rel_gpv''_map_gpv'[where R="(=)", folded rel_gpv_conv_rel_gpv'']

(* Corecursive witness construction for rel_gpv'' over a composed input
   relation R OO R'. From a pair of gpvs it takes the underlying spmfs, forms
   their rel_witness_spmf, zips the generats with rel_witness_generat, and
   continues corecursively (Inr) on the continuation pair combined by
   rel_witness_fun R R'. *)
definition rel_witness_gpv :: "('a  'd  bool)  ('b  'e  bool)  ('c  'g  bool)  ('g  'f  bool)  ('a, 'b, 'c) gpv × ('d, 'e, 'f) gpv  ('a × 'd, 'b × 'e, 'g) gpv" where
  "rel_witness_gpv A C R R' = corec_gpv (
     map_spmf (map_generat id id (λ(rpv, rpv'). (Inr  rel_witness_fun R R' (rpv, rpv')))  rel_witness_generat) 
     rel_witness_spmf (rel_generat A C (rel_fun (R OO R') (rel_gpv'' A C (R OO R'))))  map_prod the_gpv the_gpv)"

(* Simp rule: selector equation for rel_witness_gpv, with the corecursive call
   made explicit in place of the Inr injection of the definition. *)
lemma rel_witness_gpv_sel [simp]:
  "the_gpv (rel_witness_gpv A C R R' (gpv, gpv')) = 
    map_spmf (map_generat id id (λ(rpv, rpv'). (rel_witness_gpv A C R R'  rel_witness_fun R R' (rpv, rpv')))  rel_witness_generat)
     (rel_witness_spmf (rel_generat A C (rel_fun (R OO R') (rel_gpv'' A C (R OO R')))) (the_gpv gpv, the_gpv gpv'))"
  unfolding rel_witness_gpv_def
  by(auto simp add: spmf.map_comp generat.map_comp o_def intro!: map_spmf_cong generat.map_cong)

(* Correctness of the witness: under the stated uniqueness/totality assumptions
   on R and R', gpv is related to the witness via R (rel_witness_gpv1) and the
   witness is related to gpv' via R' (rel_witness_gpv2). Both parts are proved
   by coinduction from the component-level witness lemmas for spmf, generat
   and functions. *)
lemma assumes "rel_gpv'' A C (R OO R') gpv gpv'"
  and R: "left_unique R" "right_total R"
  and R': "right_unique R'" "left_total R'"
shows rel_witness_gpv1: "rel_gpv'' (λa (a', b). a = a'  A a' b) (λc (c', d). c = c'  C c' d) R gpv (rel_witness_gpv A C R R' (gpv, gpv'))" (is "?thesis1")
  and rel_witness_gpv2: "rel_gpv'' (λ(a, b') b. b = b'  A a b') (λ(c, d') d. d = d'  C c d') R' (rel_witness_gpv A C R R' (gpv, gpv')) gpv'" (is "?thesis2")
proof -
  show ?thesis1 using assms(1)
  proof(coinduction arbitrary: gpv gpv')
    case rel_gpv''
    from this[THEN rel_gpv''D] show ?case
      by(auto simp add: spmf_rel_map generat.rel_map rel_fun_comp elim!: rel_fun_mono[OF rel_witness_fun1[OF _ R R']]
          rel_spmf_mono[OF rel_witness_spmf1] generat.rel_mono[THEN predicate2D, rotated -1, OF rel_witness_generat1])
  qed
  show ?thesis2 using assms(1)
  proof(coinduction arbitrary: gpv gpv')
    case rel_gpv''
    from this[THEN rel_gpv''D] show ?case
      by(simp add: spmf_rel_map) 
        (erule rel_spmf_mono[OF rel_witness_spmf2]
          , auto simp add: generat.rel_map rel_fun_comp elim!: rel_fun_mono[OF rel_witness_fun2[OF _ R R']]
          generat.rel_mono[THEN predicate2D, rotated -1, OF rel_witness_generat2])
  qed
qed

(* Negative distributivity, the counterpart of rel_gpv''_pos_distr, under
   uniqueness/totality assumptions on R and R'. The intermediate gpv is built
   from rel_witness_gpv by mapping relcompp_witness over results and outputs. *)
lemma rel_gpv''_neg_distr:
  assumes R: "left_unique R" "right_total R"
    and R': "right_unique R'" "left_total R'"
  shows "rel_gpv'' (A OO A') (C OO C') (R OO R')  rel_gpv'' A C R OO rel_gpv'' A' C' R'"
proof(rule predicate2I relcomppI)+
  fix gpv gpv''
  assume *: "rel_gpv'' (A OO A') (C OO C') (R OO R') gpv gpv''"
  let ?gpv' = "map_gpv (relcompp_witness A A') (relcompp_witness C C') (rel_witness_gpv (A OO A') (C OO C') R R' (gpv, gpv''))"
  show "rel_gpv'' A C R gpv ?gpv'" using rel_witness_gpv1[OF * R R'] unfolding rel_gpv''_map_gpv
    by(rule rel_gpv''_mono[THEN predicate2D, rotated -1]; clarify del: relcomppE elim!: relcompp_witness)
  show "rel_gpv'' A' C' R' ?gpv' gpv''" using rel_witness_gpv2[OF * R R'] unfolding rel_gpv''_map_gpv
    by(rule rel_gpv''_mono[THEN predicate2D, rotated -1]; clarify del: relcomppE elim!: relcompp_witness)
qed

(* Pointwise version of rel_gpv''_mono, declared with the [mono] attribute.
   As in rel_gpv''_mono, the hypothesis on the input relations goes from R'
   to R. *)
lemma rel_gpv''_mono' [mono]:
  assumes "x y. A x y  A' x y"
    and "x y. C x y  C' x y"
    and "x y. R' x y  R x y"
  shows "rel_gpv'' A C R gpv gpv'  rel_gpv'' A' C' R' gpv gpv'"
  using rel_gpv''_mono[of A A' C C' R' R] assms by(blast)

(* rel_gpv'' is left-total when A and C are left-total and R is left-unique and
   right-total. Proved from negative distributivity and monotonicity. *)
lemma left_total_rel_gpv':
  " left_total A; left_total C; left_unique R; right_total R   left_total (rel_gpv'' A C R)"
unfolding left_unique_alt_def left_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[rotated])
apply(rule rel_gpv''_neg_distr; simp add: left_unique_alt_def)
apply(rule rel_gpv''_mono; assumption)
done

(* rel_gpv'' is right-total when A and C are right-total and R is right-unique
   and left-total; same proof pattern as left_total_rel_gpv'. *)
lemma right_total_rel_gpv':
  " right_total A; right_total C; right_unique R; left_total R   right_total (rel_gpv'' A C R)"
unfolding right_unique_alt_def right_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[rotated])
apply(rule rel_gpv''_neg_distr; simp add: right_unique_alt_def)
apply(rule rel_gpv''_mono; assumption)
done

(* rel_gpv'' is bi-total when A and C are bi-total and R is both bi-unique and
   bi-total; combines the left- and right-total lemmas. *)
lemma bi_total_rel_gpv' [transfer_rule]:
  " bi_total A; bi_total C; bi_unique R; bi_total R   bi_total (rel_gpv'' A C R)"
unfolding bi_total_alt_def bi_unique_alt_def by(blast intro: left_total_rel_gpv' right_total_rel_gpv')

(* Expresses rel_fun between the converse of a function graph (Grp UNIV f) and
   a function graph Grp B g as the graph of map_fun f g on a restricted
   domain. *)
lemma rel_fun_conversep_grp_grp:
  "rel_fun (conversep (BNF_Def.Grp UNIV f)) (BNF_Def.Grp B g) = BNF_Def.Grp {x. (x  f) ` UNIV  B} (map_fun f g)"
unfolding rel_fun_def Grp_def simp_thms fun_eq_iff conversep_iff by auto

(* Quotient theorem for gpv: three Quotient facts for the result, output and
   input components lift to a Quotient fact on gpvs, with rel_gpv'' as
   relation / transfer relation and map_gpv' as abstraction / representation
   map. For the input component the roles of Abs and Rep are swapped
   (Rep3 inside the abstraction map, Abs3 inside the representation map). *)
lemma Quotient_gpv:
  assumes Q1: "Quotient R1 Abs1 Rep1 T1"
  and Q2: "Quotient R2 Abs2 Rep2 T2"
  and Q3: "Quotient R3 Abs3 Rep3 T3"
  shows "Quotient (rel_gpv'' R1 R2 R3) (map_gpv' Abs1 Abs2 Rep3) (map_gpv' Rep1 Rep2 Abs3) (rel_gpv'' T1 T2 T3)"
  (is "Quotient ?R ?abs ?rep ?T")
unfolding Quotient_alt_def2
proof(intro conjI strip iffI; (elim conjE exE)?)
  (* Local automation setup for the coinductive sub-proofs below. *)
  note [simp] = spmf_rel_map generat.rel_map
    and [elim!] = rel_spmf_mono generat.rel_mono_strong
    and [rule del] = rel_funI and [intro!] = rel_funI
  (* Component-level facts extracted from the three Quotient assumptions. *)
  have Abs1 [simp]: "Abs1 x = y" if "T1 x y" for x y using Q1 that by(simp add: Quotient_alt_def)
  have Abs2 [simp]: "Abs2 x = y" if "T2 x y" for x y using Q2 that by(simp add: Quotient_alt_def)
  have Abs3 [simp]: "Abs3 x = y" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def)
  have Rep1: "T1 (Rep1 x) x" for x using Q1 by(simp add: Quotient_alt_def)
  have Rep2: "T2 (Rep2 x) x" for x using Q2 by(simp add: Quotient_alt_def)
  have Rep3: "T3 (Rep3 x) x" for x using Q3 by(simp add: Quotient_alt_def)
  have T1: "T1 x (Abs1 y)" if "R1 x y" for x y using Q1 that by(simp add: Quotient_alt_def2)
  have T2: "T2 x (Abs2 y)" if "R2 x y" for x y using Q2 that by(simp add: Quotient_alt_def2)
  have T1': "T1 x (Abs1 y)" if "R1 y x" for x y using Q1 that by(simp add: Quotient_alt_def2)
  have T2': "T2 x (Abs2 y)" if "R2 y x" for x y using Q2 that by(simp add: Quotient_alt_def2)
  have R3: "R3 x (Rep3 y)" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def2 Abs3[OF Rep3])
  have R3': "R3 (Rep3 y) x" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def2 Abs3[OF Rep3])
  have r1: "R1 = T1 OO T1¯¯" using Q1 by(simp add: Quotient_alt_def4)
  have r2: "R2 = T2 OO T2¯¯" using Q2 by(simp add: Quotient_alt_def4)
  have r3: "R3 = T3 OO T3¯¯" using Q3 by(simp add: Quotient_alt_def4)
  (* The goals produced by Quotient_alt_def2, each shown by coinduction. *)
  show abs: "?abs gpv = gpv'" if "?T gpv gpv'" for gpv gpv' using that
    by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 4 intro: Rep3 dest: rel_funD)
  show "?T (?rep gpv) gpv" for gpv
    by(coinduction arbitrary: gpv)(auto simp add: Rep1 Rep2 intro!: rel_spmf_reflI generat.rel_refl_strong)
  show "?T gpv (?abs gpv')" if "?R gpv gpv'" for gpv gpv' using that
    by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 3 simp add: T1 T2 intro!: R3 dest: rel_funD)
  show "?T gpv (?abs gpv')" if "?R gpv' gpv" for gpv gpv'
  proof -
    from that have "rel_gpv'' R1¯¯ R2¯¯ R3¯¯ gpv gpv'" unfolding rel_gpv''_conversep by simp
    then show ?thesis
      by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 3 simp add: T1' T2' intro!: R3' dest: rel_funD)
  qed
  (* Last goal: recover ?R from two ?T facts via positive distributivity and
     the decompositions r1, r2, r3 of the component relations. *)
  show "?R gpv gpv'" if "?T gpv (?abs gpv')" "?T gpv' (?abs gpv)" for gpv gpv'
  proof -
    from that[THEN abs] have "?abs gpv' = ?abs gpv" by simp
    with that have "(?T OO ?T¯¯) gpv gpv'" by(auto simp del: rel_gpv''_map_gpv')
    hence "rel_gpv'' (T1 OO T1¯¯) (T2 OO T2¯¯) (T3 OO T3¯¯) gpv gpv'"
      unfolding rel_gpv''_conversep[symmetric]
      by(rule rel_gpv''_pos_distr[THEN predicate2D])
    thus ?thesis by(simp add: r1 r2 r3)
  qed
qed

(* Parametricity of the selector the_gpv with respect to rel_gpv''. *)
lemma the_gpv_parametric':
  "(rel_gpv'' A C R ===> rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R))) the_gpv the_gpv"
by(rule rel_funI)(auto elim: rel_gpv''.cases)

(* Parametricity of the constructor GPV with respect to rel_gpv''. *)
lemma GPV_parametric':
  "(rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) ===> rel_gpv'' A C R) GPV GPV"
by(rule rel_funI)(auto)

(* Parametricity of the corecursor corec_gpv with respect to rel_gpv'':
   S-related seeds and related step functions yield related gpvs. Proved by
   coinduction over the seeds. *)
lemma corec_gpv_parametric':
  "((S ===> rel_spmf (rel_generat A C (R ===> rel_sum (rel_gpv'' A C R) S))) ===> S ===> rel_gpv'' A C R)
  corec_gpv corec_gpv"
proof(rule rel_funI)+
  fix f g s1 s2
  assume fg: "(S ===> rel_spmf (rel_generat A C (R ===> rel_sum (rel_gpv'' A C R) S))) f g"
    and s: "S s1 s2"
  from s show "rel_gpv'' A C R (corec_gpv f s1) (corec_gpv g s2)"
    apply(coinduction arbitrary: s1 s2)
    apply(drule fg[THEN rel_funD])
    apply(simp add: spmf_rel_map)
    apply(erule rel_spmf_mono)
    apply(simp add: generat.rel_map)
    apply(erule generat.rel_mono_strong; clarsimp simp add: o_def)
    apply(rule rel_funI)
    apply(drule (1) rel_funD)
    apply(auto 4 3 elim!: rel_sum.cases)
    done
qed

(* Parametricity of map_gpv' with respect to rel_gpv''; the input translation
   is related in the reverse direction (R' ===> R). Proved by transfer_prover
   after unfolding the definition, with the parametricity of corec_gpv and
   the_gpv supplied as local transfer rules. *)
lemma map_gpv'_parametric [transfer_rule]:
  "((A ===> A') ===> (C ===> C') ===> (R' ===> R) ===> rel_gpv'' A C R ===> rel_gpv'' A' C' R') map_gpv' map_gpv'"
  unfolding map_gpv'_def
  supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
  by(transfer_prover)

(* Parametricity of map_gpv with respect to rel_gpv'' (input relation R
   unchanged), obtained by rewriting map_gpv to map_gpv' and running
   transfer_prover. *)
lemma map_gpv_parametric': "((A ===> A') ===> (C ===> C') ===> rel_gpv'' A C R ===> rel_gpv'' A' C' R) map_gpv map_gpv"
  unfolding map_gpv_conv_map_gpv'[abs_def] by transfer_prover

end

subsection ‹Simple, derived operations›

(* The gpv whose underlying spmf is return_spmf (Pure a). *)
primcorec Done :: "'a  ('a, 'out, 'in) gpv"
where "the_gpv (Done a) = return_spmf (Pure a)"

(* The gpv whose underlying spmf is return_spmf (IO out c), for an output out
   and a continuation c from inputs to gpvs. *)
primcorec Pause :: "'out  ('in  ('a, 'out, 'in) gpv)  ('a, 'out, 'in) gpv"
where "the_gpv (Pause out c) = return_spmf (IO out c)"

(* Embeds an spmf into gpv by wrapping each of its values in Pure. *)
primcorec lift_spmf :: "'a spmf  ('a, 'out, 'in) gpv"
where "the_gpv (lift_spmf p) = map_spmf Pure p"

(* The gpv whose underlying spmf is return_pmf None. *)
definition Fail :: "('a, 'out, 'in) gpv"
where "Fail = GPV (return_pmf None)"

(* Builds a reactive value from f: on an input, f yields a pair of an output
   and a follow-up reactive value, which is turned into a Pause. *)
definition React :: "('in  'out × ('a, 'out, 'in) rpv)  ('a, 'out, 'in) rpv"
where "React f input = case_prod Pause (f input)"

(* The reactive value that ignores its input and returns Fail. *)
definition rFail :: "('a, 'out, 'in) rpv"
where "rFail = (λ_. Fail)"

(* Simp rule: Done is injective. *)
lemma Done_inject [simp]: "Done x = Done y  x = y"
by(simp add: Done.ctr)

(* Simp rule: Pause is injective in both arguments. *)
lemma Pause_inject [simp]: "Pause out c = Pause out' c'  out = out'  c = c'"
by(simp add: Pause.ctr)

(* Simp rules: Done and Pause are distinct (both orientations). *)
lemma [simp]:
  shows Done_neq_Pause: "Done x  Pause out c"
  and Pause_neq_Done: "Pause out c  Done x"
by(simp_all add: Done.ctr Pause.ctr)

(* Simp rule: Done has no outputs. *)
lemma outs'_gpv_Done [simp]: "outs'_gpv (Done x) = {}"
by(auto elim: outs'_gpv_cases)

(* Simp rule: the only result of Done x is x. *)
lemma results'_gpv_Done [simp]: "results'_gpv (Done x) = {x}"
by(auto intro: results'_gpvI elim: results'_gpv_cases)

lemma pred_gpv_Done [simp]: "pred_gpv P Q (Done x) = P x"
by(simp add: pred_gpv_def)

(* The outputs of Pause out c are out itself together with the outputs of the
   continuation c, collected over all inputs. *)
lemma outs'_gpv_Pause [simp]: "outs'_gpv (Pause out c) = insert out (input. outs'_gpv (c input))"
by(auto 4 4 intro: outs'_gpvI elim: outs'_gpv_cases)

(* The results of Pause out rpv are exactly the results of the reactive value
   rpv; the output contributes nothing. *)
lemma results'_gpv_Pause [simp]: "results'_gpv (Pause out rpv) = results'_rpv rpv"
by(auto 4 4 intro: results'_gpvI elim: results'_gpv_cases)

(* pred_gpv on Pause x c: Q is required of the output x and pred_gpv P Q of
   the continuation c at every input (All applied to the composition). *)
lemma pred_gpv_Pause [simp]: "pred_gpv P Q (Pause x c) = (Q x  All (pred_gpv P Q  c))"
by(auto simp add: pred_gpv_def o_def)

(* Lifting return_spmf x gives Done x. *)
lemma lift_spmf_return [simp]: "lift_spmf (return_spmf x) = Done x"
by(simp add: lift_spmf.ctr Done.ctr)

(* Lifting the spmf return_pmf None gives Fail; shown by comparing the
   unfoldings (gpv.expand). *)
lemma lift_spmf_None [simp]: "lift_spmf (return_pmf None) = Fail"
by(rule gpv.expand)(simp add: Fail_def)

(* Restates the defining selector equation of lift_spmf as a named simp
   rule. *)
lemma the_gpv_lift_spmf [simp]: "the_gpv (lift_spmf r) = map_spmf Pure r"
by(simp)

(* A lifted spmf has no outputs. *)
lemma outs'_gpv_lift_spmf [simp]: "outs'_gpv (lift_spmf p) = {}"
by(auto 4 3 elim: outs'_gpv_cases)

(* The results of a lifted spmf are exactly the support set_spmf p. *)
lemma results'_gpv_lift_spmf [simp]: "results'_gpv (lift_spmf p) = set_spmf p"
by(auto 4 3 elim: results'_gpv_cases intro: results'_gpvI)

(* pred_gpv on a lifted spmf reduces to pred_spmf with the result predicate
   P; the output predicate Q is irrelevant. *)
lemma pred_gpv_lift_spmf [simp]: "pred_gpv P Q (lift_spmf p) = pred_spmf P p"
by(simp add: pred_gpv_def pred_spmf_def)

(* lift_spmf is injective.  The proof unfolds lift_spmf.code and then uses
   the strong injectivity rules of the pmf and option map functions
   (inj_map_strong, with premises rotated) as destruction rules. *)
lemma lift_spmf_inject [simp]: "lift_spmf p = lift_spmf q  p = q"
by(auto simp add: lift_spmf.code dest!: pmf.inj_map_strong[rotated] option.inj_map_strong[rotated])

(* Mapping over a lifted spmf: only the result function f matters, the output
   function g disappears.  Oriented to push map_gpv inside lift_spmf. *)
lemma map_lift_spmf: "map_gpv f g (lift_spmf p) = lift_spmf (map_spmf f p)"
by(rule gpv.expand)(simp add: gpv.map_sel spmf.map_comp o_def)

(* The same fact in the opposite orientation, with id as the output function.
   Neither orientation is a simp rule, so the two cannot loop by default. *)
lemma lift_map_spmf: "lift_spmf (map_spmf f p) = map_gpv f id (lift_spmf p)"
by(rule gpv.expand)(simp add: gpv.map_sel spmf.map_comp o_def)

(* Fail is distinct from Pause and from Done; every fact is given in both
   orientations as a simp rule. *)
lemma [simp]:
  shows Fail_neq_Pause: "Fail  Pause out c"
  and Pause_neq_Fail: "Pause out c  Fail"
  and Fail_neq_Done: "Fail  Done x"
  and Done_neq_Fail: "Done x  Fail"
by(simp_all add: Fail_def Pause.ctr Done.ctr)

text ‹Add @{typ unit} closure to circumvent SML value restriction›

(* Fail' is Fail behind a dummy unit argument.  Its defining equation is
   removed from the code equations ([code del]). *)
definition Fail' :: "unit  ('a, 'out, 'in) gpv"
where [code del]: "Fail' _ = Fail"

(* Code preprocessing rule: occurrences of Fail are unfolded to Fail' (). *)
lemma Fail_code [code_unfold]: "Fail = Fail' ()"
by(simp add: Fail'_def)

(* The code equation actually used for Fail'. *)
lemma Fail'_code [code]:
  "Fail' x = GPV (return_pmf None)"
by(simp add: Fail'_def Fail_def)

(* Unfolding of Fail. *)
lemma Fail_sel [simp]:
  "the_gpv Fail = return_pmf None"
by(simp add: Fail_def)

(* Fail equals GPV f exactly for f = return_pmf None. *)
lemma Fail_eq_GPV_iff [simp]: "Fail = GPV f  f = return_pmf None"
by(auto simp add: Fail_def)

(* Fail has no outputs. *)
lemma outs'_gpv_Fail [simp]: "outs'_gpv Fail = {}"
by(auto elim: outs'_gpv_cases)

(* Fail has no results. *)
lemma results'_gpv_Fail [simp]: "results'_gpv Fail = {}"
by(auto elim: results'_gpv_cases)

(* pred_gpv holds of Fail for arbitrary P and Q. *)
lemma pred_gpv_Fail [simp]: "pred_gpv P Q Fail"
by(simp add: pred_gpv_def)

(* React is injective; declared [iff]. *)
lemma React_inject [iff]: "React f = React f'  f = f'"
by(auto simp add: React_def fun_eq_iff split_def intro: prod.expand)

(* Conditional simp rule: once f input is known to be the pair (out, c),
   React f input is Pause out c. *)
lemma React_apply [simp]: "f input = (out, c)  React f input = Pause out c"
by(simp add: React_def)

(* rFail applied to any input is Fail. *)
lemma rFail_apply [simp]: "rFail input = Fail"
by(simp add: rFail_def)

(* rFail and React f are distinct, in both orientations. *)
lemma [simp]:
  shows rFail_neq_React: "rFail  React f"
  and React_neq_rFail: "React f  rFail"
by(simp_all add: React_def fun_eq_iff split_beta)

(* Characterisations of the relators rel_gpv (two relations; inputs compared
   by equality, cf. Pause_parametric) and rel_gpv'' (three relations, the third
   one relating inputs) on the derived constructors.  The rel_gpv facts unfold
   one step with gpv.rel_sel, the rel_gpv'' facts with rel_gpv''.simps. *)

(* Fail is related to Fail for any A and C. *)
lemma rel_gpv_FailI [simp]: "rel_gpv A C Fail Fail"
by(subst gpv.rel_sel) simp

(* Two Done values are related exactly when their results are related by A. *)
lemma rel_gpv_Done [iff]: "rel_gpv A C (Done x) (Done y)  A x y"
by(subst gpv.rel_sel) simp

lemma rel_gpv''_Done [iff]: "rel_gpv'' A C R (Done x) (Done y)  A x y"
by(subst rel_gpv''.simps) simp

(* Two Pause values: the outputs must be related by C and the continuations
   must be related at every (identical) input x. *)
lemma rel_gpv_Pause [iff]:
  "rel_gpv A C (Pause out c) (Pause out' c')  C out out'  (x. rel_gpv A C (c x) (c' x))"
by(subst gpv.rel_sel)(simp add: rel_fun_def)

(* As above for rel_gpv'', but the continuations are compared at inputs x and
   x' that are related by R. *)
lemma rel_gpv''_Pause [iff]:
  "rel_gpv'' A C R (Pause out c) (Pause out' c')  C out out'  (x x'. R x x'  rel_gpv'' A C R (c x) (c' x'))"
by(subst rel_gpv''.simps)(simp add: rel_fun_def)

(* Two lifted spmfs are related exactly when the spmfs are related by
   rel_spmf A. *)
lemma rel_gpv_lift_spmf [iff]: "rel_gpv A C (lift_spmf p) (lift_spmf q)  rel_spmf A p q"
by(subst gpv.rel_sel)(simp add: spmf_rel_map)

lemma rel_gpv''_lift_spmf [iff]:
  "rel_gpv'' A C R (lift_spmf p) (lift_spmf q)  rel_spmf A p q"
by(subst rel_gpv''.simps)(simp add: spmf_rel_map)

(* Parametricity rules for Fail, Done, Pause and lift_spmf.  The unprimed
   rules are for rel_gpv and are registered as transfer rules; the primed
   rules are the analogues for rel_gpv'' and are not registered
   (Fail_parametric' is a simp rule instead). *)
context includes lifting_syntax begin
(* Fail_parametric is just rel_gpv_FailI under a transfer-rule name. *)
lemmas Fail_parametric [transfer_rule] = rel_gpv_FailI

lemma Fail_parametric' [simp]: "rel_gpv'' A C R Fail Fail"
unfolding Fail_def by simp

lemma Done_parametric [transfer_rule]: "(A ===> rel_gpv A C) Done Done"
by(rule rel_funI) simp

lemma Done_parametric': "(A ===> rel_gpv'' A C R) Done Done"
by(rule rel_funI) simp

(* For rel_gpv the continuations are related pointwise at equal inputs
   ((=) ===> ...). *)
lemma Pause_parametric [transfer_rule]:
  "(C ===> ((=) ===> rel_gpv A C) ===> rel_gpv A C) Pause Pause"
by(simp add: rel_fun_def)

(* For rel_gpv'' the continuations are related at R-related inputs. *)
lemma Pause_parametric':
  "(C ===> (R ===> rel_gpv'' A C R) ===> rel_gpv'' A C R) Pause Pause"
by(simp add: rel_fun_def)

lemma lift_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> rel_gpv A C) lift_spmf lift_spmf"
by(simp add: rel_fun_def)

lemma lift_spmf_parametric':
  "(rel_spmf A ===> rel_gpv'' A C R) lift_spmf lift_spmf"
by(simp add: rel_fun_def)
end

(* Simp rules for map_gpv and map_gpv' on the derived constructors. *)

(* On Done only the result function f is applied. *)
lemma map_gpv_Done [simp]: "map_gpv f g (Done x) = Done (f x)"
by(simp add: Done.code)

lemma map_gpv'_Done [simp]: "map_gpv' f g h (Done x) = Done (f x)"
by(simp add: Done.code)

(* On Pause the output is mapped with g and the map is pushed into the
   continuation. *)
lemma map_gpv_Pause [simp]: "map_gpv f g (Pause x c) = Pause (g x) (map_gpv f g  c)"
by(simp add: Pause.code)

(* For map_gpv' the continuation is additionally combined with the input
   function h, which appears on the input side of c. *)
lemma map_gpv'_Pause [simp]: "map_gpv' f g h (Pause x c) = Pause (g x) (map_gpv' f g h  c  h)"
by(simp add: Pause.code map_fun_def)

(* Mapping Fail gives Fail. *)
lemma map_gpv_Fail [simp]: "map_gpv f g Fail = Fail"
by(simp add: Fail_def)

lemma map_gpv'_Fail [simp]: "map_gpv' f g h Fail = Fail"
by(simp add: Fail_def)

subsection ‹Monad structure›

(* Monadic bind for GPVs, defined by primitive corecursion.

   The inner bind over the_gpv r inspects the first step of r:
   - on a Pure x step it continues with the unfolding of f x, tagging the
     continuations of f x with Inl;
   - on an IO out c step it re-emits out, tagging the continuation c with Inr.
   The outer map then resolves the tags with case_sum: Inl values are passed
   through unchanged (id), whereas Inr values are the remainder of r and are
   bound with f corecursively. *)
primcorec bind_gpv :: "('a, 'out, 'in) gpv  ('a  ('b, 'out, 'in) gpv)  ('b, 'out, 'in) gpv"
where
  "the_gpv (bind_gpv r f) =
   map_spmf (map_generat id id ((∘) (case_sum id (λr. bind_gpv r f))))
     (the_gpv r 
      (case_generat
        (λx. map_spmf (map_generat id id ((∘) Inl)) (the_gpv (f x)))
        (λout c. return_spmf (IO out (λinput. Inr (c input))))))"

(* The raw selector equation with its Inl/Inr encoding is removed from the
   simp set; bind_gpv_sel' further down is a simp rule in its place. *)
declare bind_gpv.sel [simp del]

(* Makes the generic monad bind syntax available for bind_gpv. *)
adhoc_overloading Monad_Syntax.bind bind_gpv

(* One-step unfolding of bind without the Inl/Inr encoding, used as the code
   equation: run the first step of r; on Pure x continue with the unfolding of
   f x, on IO out c emit out and bind every continuation with f.
   The proof unfolds bind_gpv_def, compares unfoldings (gpv.expand), strips
   the common bind over the_gpv r by congruence, and finishes with a case
   split on the generat value. *)
lemma bind_gpv_unfold [code]:
  "r  f = GPV (
   do {
     generat  the_gpv r;
     case generat of Pure x  the_gpv (f x)
       | IO out c  return_spmf (IO out (λinput. c input  f))
   })"
unfolding bind_gpv_def
apply(rule gpv.expand)
apply(simp add: map_spmf_bind_spmf)
apply(rule arg_cong[where f="bind_spmf (the_gpv r)"])
apply(auto split: generat.split simp add: map_spmf_bind_spmf fun_eq_iff spmf.map_comp o_def generat.map_comp id_def[symmetric] generat.map_id pmf.map_id option.map_id)
done

(* Congruence rule for bind_gpv that only allows rewriting in the first
   argument; the second argument g is the same on both sides.  The setup line
   adds it to the simpset of Code_Simp. *)
lemma bind_gpv_code_cong: "f = f'  bind_gpv f g = bind_gpv f' g" by simp
setup Code_Simp.map_ss (Simplifier.add_cong @{thm bind_gpv_code_cong})

(* Selector form of bind_gpv_unfold: the unfolding of a bind, written with a
   case distinction on the first step of r.  Not a simp rule. *)
lemma bind_gpv_sel:
  "the_gpv (r  f) =
   do {
     generat  the_gpv r;
     case generat of Pure x  the_gpv (f x)
       | IO out c  return_spmf (IO out (λinput. bind_gpv (c input) f))
   }"
by(subst bind_gpv_unfold) simp

(* Variant of bind_gpv_sel that uses the discriminator is_Pure and the
   selectors result / output / continuation in an if-then-else instead of a
   case expression.  This is the version installed as a simp rule.
   Proof: unfold bind_gpv_sel, strip the common bind by congruence, and split
   on the generat value. *)
lemma bind_gpv_sel' [simp]:
  "the_gpv (r  f) =
   do {
     generat  the_gpv r;
     if is_Pure generat then the_gpv (f (result generat))
     else return_spmf (IO (output generat) (λinput. bind_gpv (continuation generat input) f))
   }"
unfolding bind_gpv_sel
by(rule arg_cong[where f="bind_spmf (the_gpv r)"])(simp add: fun_eq_iff split: generat.split)

(* Left unit law of the monad: binding Done a with f is f a. *)
lemma Done_bind_gpv [simp]: "Done a  f = f a"
by(rule gpv.expand)(simp)

(* Right unit law of the monad: binding f with Done is f.
   Proved by coinduction over gpv equality.  The auxiliary fact * rewrites the
   inner bind of the primcorec definition (with Done substituted) into a map
   with Inr over a bind with return_spmf.  The final step works with the
   primcorec equations bind_gpv.simps and therefore removes bind_gpv_sel' from
   the simp set. *)
lemma bind_gpv_Done [simp]: "f  Done = f"
proof(coinduction arbitrary: f rule: gpv.coinduct)
  case (Eq_gpv f)
  have *: "the_gpv f  (case_generat (λx. return_spmf (Pure x)) (λout c. return_spmf (IO out (λinput. Inr (c input))))) =
           map_spmf (map_generat id id ((∘) Inr)) (bind_spmf (the_gpv f) return_spmf)"
    unfolding map_spmf_bind_spmf
    by(rule arg_cong2[where f=bind_spmf])(auto simp add: fun_eq_iff split: generat.split)
  show ?case
    by(auto simp add: * bind_gpv.simps pmf.rel_map option.rel_map[abs_def] generat.rel_map[abs_def] simp del: bind_gpv_sel' intro!: rel_generatI rel_spmf_reflI)
qed

(* An if whose condition b does not depend on the bound value can be moved out
   of the second argument of bind_gpv.  Collected in the named theorem list
   if_distribs. *)
lemma if_distrib_bind_gpv2 [if_distribs]:
  "bind_gpv gpv (λy. if b then f y else g y) = (if b then bind_gpv gpv f else bind_gpv gpv g)"
by simp

(* Binding a lifted spmf r with f is the GPV obtained by binding r, at the
   spmf level, with the unfolding of f. *)
lemma lift_spmf_bind: "lift_spmf r  f = GPV (r  the_gpv  f)"
by(coinduction arbitrary: r f rule: gpv.coinduct_strong)(auto simp add: bind_map_spmf o_def intro: rel_pmf_reflI rel_optionI rel_generatI)

(* Selector version of the previous fact, as a simp rule. *)
lemma the_gpv_bind_gpv_lift_spmf [simp]:
  "the_gpv (bind_gpv (lift_spmf p) f) = bind_spmf p (the_gpv  f)"
by(simp add: bind_map_spmf o_def)

(* lift_spmf commutes with bind: lifting an spmf-level bind equals the
   GPV-level bind of the lifted parts (stated with a lambda). *)
lemma lift_spmf_bind_spmf: "lift_spmf (p  f) = lift_spmf p  (λx. lift_spmf (f x))"
by(rule gpv.expand)(simp add: lift_spmf_bind o_def map_spmf_bind_spmf)

(* The same commutation stated with explicit constants bind_spmf / bind_gpv
   and a function composition instead of a lambda. *)
lemma lift_bind_spmf: "lift_spmf (bind_spmf p f) = bind_gpv (lift_spmf p) (lift_spmf  f)"
by(rule gpv.expand)(simp add: bind_map_spmf map_spmf_bind_spmf o_def)

(* Bind of a GPV given by an explicit constructor application GPV f, in the
   case-expression form of bind_gpv_unfold. *)
lemma GPV_bind:
  "GPV f  g = 
   GPV (f  (λgenerat. case generat of Pure x  the_gpv (g x) | IO out c  return_spmf (IO out (λinput. c input  g))))"
by(subst bind_gpv_unfold) simp

(* The same with discriminator and selectors in an if-then-else. *)
lemma GPV_bind':
  "GPV f  g = GPV (f  (λgenerat. if is_Pure generat then the_gpv (g (result generat)) else return_spmf (IO (output generat) (λinput. continuation generat input  g))))"
unfolding GPV_bind gpv.inject
by(rule arg_cong[where f="bind_spmf f"])(simp add: fun_eq_iff split: generat.split)

(* Associativity of bind_gpv.  Proved by strong coinduction over gpv
   equality, generalising over f, g and h.  After simplification (with the
   congruence rule if_weak_cong disabled so that the simplifier may rewrite
   inside the if branches) both sides are binds over the same spmf, so
   rel_spmf_bindI is applied with equality as the intermediate relation. *)
lemma bind_gpv_assoc:
  fixes f :: "('a, 'out, 'in) gpv"
  shows "(f  g)  h = f  (λx. g x  h)"
proof(coinduction arbitrary: f g h rule: gpv.coinduct_strong)
  case (Eq_gpv f g h)
  show ?case
    apply(simp cong del: if_weak_cong)
    apply(rule rel_spmf_bindI[where R="(=)"])
     apply(simp add: option.rel_eq pmf.rel_eq)
    apply(fastforce intro: rel_pmf_return_pmfI rel_generatI rel_spmf_reflI)
    done
qed

(* map_gpv distributes over bind: the output function g is applied to both
   parts, the result function f only to the GPVs produced by h (the first
   part is mapped with id on results).  Proved by strong coinduction using the
   raw selector equation bind_gpv.sel, hence bind_gpv_sel' is deleted from the
   simp set in that step. *)
lemma map_gpv_bind_gpv: "map_gpv f g (bind_gpv gpv h) = bind_gpv (map_gpv id g gpv) (λx. map_gpv f g (h x))"
apply(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
apply(simp add: bind_gpv.sel gpv.map_sel spmf_rel_map generat.rel_map o_def bind_map_spmf del: bind_gpv_sel')
apply(rule rel_spmf_bind_reflI)
apply(auto simp add: spmf_rel_map generat.rel_map split: generat.split del: rel_funI intro!: rel_spmf_reflI generat.rel_refl rel_funI)
done

(* Special case with id on outputs: the map moves entirely into the second
   argument of bind. *)
lemma map_gpv_id_bind_gpv: "map_gpv f id (bind_gpv gpv g) = bind_gpv gpv (map_gpv f id  g)"
by(simp add: map_gpv_bind_gpv gpv.map_id o_def)

(* Mapping results only (identity on outputs) can be expressed as a bind that
   ends in Done. *)
lemma map_gpv_conv_bind:
  "map_gpv f (λx. x) x = bind_gpv x (λx. Done (f x))"
using map_gpv_bind_gpv[of f "λx. x" x Done] by(simp add: id_def[symmetric] gpv.map_id)

(* A result map in the first argument of bind can be fused into the second
   argument by composition. *)
lemma bind_map_gpv: "bind_gpv (map_gpv f id gpv) g = bind_gpv gpv (g  f)"
by(simp add: map_gpv_conv_bind id_def bind_gpv_assoc o_def)

(* The outputs of a bind are the outputs of the first GPV together with the
   outputs of f applied to each result of the first GPV.

   Proof structure: set equality is split into the two inclusions.
   - lhs into rhs: induction over membership in outs'_gpv of the bind, with
     the bind generalised to an arbitrary x.  Both rule cases (Out and Cont)
     obtain the step generat' of the_gpv x from which the observed step
     generat originates and then distinguish whether generat' is Pure (the
     output stems from f applied to a result of x) or not (it stems from x
     itself, possibly via the induction hypothesis in the Cont case).
   - rhs into lhs: case distinction on the two parts of the rhs, each handled
     by induction over the respective membership (outs'_gpv x, or
     results'_gpv x for the witness y). *)
lemma outs_bind_gpv:
  "outs'_gpv (bind_gpv x f) = outs'_gpv x  (x  results'_gpv x. outs'_gpv (f x))"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  fix out
  assume "out  ?lhs"
  then show "out  ?rhs"
  proof(induction g"x  f" arbitrary: x)
    case (Out generat)
    then obtain generat' where *: "generat'  set_spmf (the_gpv x)"
      and **: "generat  set_spmf (if is_Pure generat' then the_gpv (f (result generat'))
                                else return_spmf (IO (output generat') (λinput. continuation generat' input  f)))"
      by(auto)
    show ?case
    proof(cases "is_Pure generat'")
      case True
      then have "out  outs'_gpv (f (result generat'))" using Out(2) ** by(auto intro: outs'_gpvI)
      moreover have "result generat'  results'_gpv x" using * True
        by(auto intro: results'_gpvI generat.set_sel)
      ultimately show ?thesis by blast
    next
      case False
      hence "out  outs'_gpv x" using * ** Out(2) by(auto intro: outs'_gpvI generat.set_sel)
      thus ?thesis by blast
    qed
  next
    case (Cont generat c input)
    then obtain generat' where *: "generat'  set_spmf (the_gpv x)"
      and **: "generat  set_spmf (if is_Pure generat' then the_gpv (f (generat.result generat'))
                                 else return_spmf (IO (generat.output generat') (λinput. continuation generat' input  f)))"
      by(auto)
    show ?case
    proof(cases "is_Pure generat'")
      case True
      then have "out  outs'_gpv (f (result generat'))" using Cont(2-3) ** by(auto intro: outs'_gpvI)
      moreover have "result generat'  results'_gpv x" using * True
        by(auto intro: results'_gpvI generat.set_sel)
      ultimately show ?thesis by blast
    next
      case False
      then have generat: "generat = IO (output generat') (λinput. continuation generat' input  f)"
        using ** by simp
      with Cont(2) have "c input = continuation generat' input  f" by auto
      hence "out  outs'_gpv (continuation generat' input)  (xresults'_gpv (continuation generat' input). outs'_gpv (f x))"
        by(rule Cont)
      thus ?thesis
      proof
        assume "out  outs'_gpv (continuation generat' input)"
        with * ** False have "out  outs'_gpv x" by(auto intro: outs'_gpvI generat.set_sel)
        thus ?thesis ..
      next
        assume "out  (xresults'_gpv (continuation generat' input). outs'_gpv (f x))"
        then obtain y where "y  results'_gpv (continuation generat' input)" "out  outs'_gpv (f y)" ..
        from y  _ * ** False have "y  results'_gpv x" 
          by(auto intro: results'_gpvI generat.set_sel)
        with out  outs'_gpv (f y) show ?thesis by blast
      qed
    qed
  qed
next
  fix out
  assume "out  ?rhs"
  then show "out  ?lhs"
  proof
    assume "out  outs'_gpv x"
    thus ?thesis
    proof(induction)
      case (Out generat gpv)
      then show ?case
        by(cases generat)(fastforce intro: outs'_gpvI rev_bexI)+
    next
      case (Cont generat gpv gpv')
      then show ?case
        by(cases generat)(auto 4 4 intro: outs'_gpvI rev_bexI simp add: in_set_spmf set_pmf_bind_spmf simp del: set_bind_spmf)
    qed
  next
    assume "out  (xresults'_gpv x. outs'_gpv (f x))"
    then obtain y where "y  results'_gpv x" "out  outs'_gpv (f y)" ..
    from y  _ show ?thesis
    proof(induction)
      case (Pure generat gpv)
      thus ?case using out  outs'_gpv _
        by(cases generat)(auto 4 5 intro: outs'_gpvI rev_bexI elim: outs'_gpv_cases)
    next
      case (Cont generat gpv gpv')
      thus ?case
        by(cases generat)(auto 4 4 simp add: in_set_spmf simp add: set_pmf_bind_spmf intro: outs'_gpvI rev_bexI simp del: set_bind_spmf)
    qed
  qed
qed

(* Fail is a left zero of bind. *)
lemma bind_gpv_Fail [simp]: "Fail  f = Fail"
by(subst bind_gpv_unfold)(simp add: Fail_def)

(* Characterises when a bind is Fail: every first step in the support of
   the_gpv gpv must be a Pure step, and f must be Fail on every result of gpv.
   The proof shows the implication from the right-hand side first; for the
   converse it derives * (the unfolding of the bind is return_pmf None) and
   reads off the two conjuncts from it. *)
lemma bind_gpv_eq_Fail:
  "bind_gpv gpv f = Fail  (xset_spmf (the_gpv gpv). is_Pure x)  (xresults'_gpv gpv. f x = Fail)"
  (is "?lhs = ?rhs")
proof(intro iffI conjI strip)
  show ?lhs if ?rhs using that
    by(intro gpv.expand)(auto 4 4 simp add: bind_eq_return_pmf_None intro: results'_gpv_Pure generat.set_sel dest: bspec)

  assume ?lhs
  hence *: "the_gpv (bind_gpv gpv f) = return_pmf None" by simp
  from * show "is_Pure x" if "x  set_spmf (the_gpv gpv)" for x using that
    by(simp add: bind_eq_return_pmf_None split: if_split_asm)
  show "f x = Fail" if "x  results'_gpv gpv" for x using that *
    by(cases)(auto 4 3 simp add: bind_eq_return_pmf_None elim!: generat.set_cases intro: gpv.expand dest: bspec)
qed

(* Parametricity of bind_gpv.  Both proofs unfold bind_gpv_def and call
   transfer_prover; the rel_gpv'' version needs the primed parametricity rules
   for corec_gpv and the_gpv supplied locally.  Only the rel_gpv version is
   registered as a transfer rule. *)
context includes lifting_syntax begin

lemma bind_gpv_parametric [transfer_rule]:
  "(rel_gpv A C ===> (A ===> rel_gpv B C) ===> rel_gpv B C) bind_gpv bind_gpv"
unfolding bind_gpv_def by transfer_prover

lemma bind_gpv_parametric':
  "(rel_gpv'' A C R ===> (A ===> rel_gpv'' B C R) ===> rel_gpv'' B C R) bind_gpv bind_gpv"
unfolding bind_gpv_def supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
by(transfer_prover)

end

(* Done and bind_gpv satisfy the monad locale; the laws follow from the simp
   rules above plus bind_gpv_assoc. *)
lemma monad_gpv [locale_witness]: "monad Done bind_gpv"
by(unfold_locales)(simp_all add: bind_gpv_assoc)

(* With Fail they additionally satisfy the monad_fail locale. *)
lemma monad_fail_gpv [locale_witness]: "monad_fail Done bind_gpv Fail"
by unfold_locales auto

(* Introduction rule for rel_gpv on binds: related first arguments and
   continuations that are related on A-related results give related binds.
   This is bind_gpv_parametric with the function relators unfolded. *)
lemma rel_gpv_bindI:
  " rel_gpv A C gpv gpv'; x y. A x y  rel_gpv B C (f x) (g y) 
   rel_gpv B C (bind_gpv gpv f) (bind_gpv gpv' g)"
by(fact bind_gpv_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])

(* Congruence rule for bind_gpv: the continuations f and g only have to agree
   on the results of gpv'.
   Proof idea: equality is expressed as rel_gpv with equality relations
   (gpv.rel_eq) and rel_gpv_bindI is applied with A instantiated to equality
   restricted to results'_gpv gpv' (eq_onp). *)
lemma bind_gpv_cong:
  " gpv = gpv'; x. x  results'_gpv gpv'  f x = g x   bind_gpv gpv f = bind_gpv gpv' g"
apply(subst gpv.rel_eq[symmetric])
apply(rule rel_gpv_bindI[where A="eq_onp (λx. x  results'_gpv gpv')"])
 apply(subst (asm) gpv.rel_eq[symmetric])
 apply(erule gpv.rel_mono_strong)
  apply(simp add: eq_onp_def)
 apply simp
apply(clarsimp simp add: gpv.rel_eq eq_onp_def)
done

(* Bind for reactive values: the rpv is applied to the input and the
   resulting GPV is bound with f. *)
definition bind_rpv :: "('a, 'in, 'out) rpv  ('a  ('b, 'in, 'out) gpv)  ('b, 'in, 'out) rpv"
where "bind_rpv rpv f = (λinput. bind_gpv (rpv input) f)"

(* Applied form of the definition, as a simp rule. *)
lemma bind_rpv_apply [simp]: "bind_rpv rpv f input = bind_gpv (rpv input) f"
by(simp add: bind_rpv_def fun_eq_iff)

(* Makes the generic monad bind syntax available for bind_rpv as well. *)
adhoc_overloading Monad_Syntax.bind bind_rpv

(* Congruence rule that only rewrites the first argument of bind_rpv, added to
   the Code_Simp simpset (same pattern as bind_gpv_code_cong). *)
lemma bind_rpv_code_cong: "rpv = rpv'  bind_rpv rpv f = bind_rpv rpv' f" by simp
setup Code_Simp.map_ss (Simplifier.add_cong @{thm bind_rpv_code_cong})

(* Right unit: binding a reactive value with Done leaves it unchanged. *)
lemma bind_rpv_rDone [simp]: "bind_rpv rpv Done = rpv"
by(simp add: bind_rpv_def)

(* Bind moves under Pause: the output is kept and the continuation is bound
   with f via bind_rpv. *)
lemma bind_gpv_Pause [simp]: "bind_gpv (Pause out rpv) f = Pause out (bind_rpv rpv f)"
by(rule gpv.expand)(simp add: fun_eq_iff)

(* Bind moves under React: it is applied to the second component of the pair
   produced by f (apsnd). *)
lemma bind_rpv_React [simp]: "bind_rpv (React f) g = React (apsnd (λrpv. bind_rpv rpv g)  f)"
by(simp add: React_def split_beta fun_eq_iff)

(* Associativity for bind_rpv, reduced to bind_gpv_assoc pointwise. *)
lemma bind_rpv_assoc: "bind_rpv (bind_rpv rpv f) g = bind_rpv rpv ((λgpv. bind_gpv gpv g)  f)"
by(simp add: fun_eq_iff bind_gpv_assoc o_def)

(* Here Done itself is used as a reactive value (a function from the input to
   a GPV); binding it with f gives f. *)
lemma bind_rpv_Done [simp]: "bind_rpv Done f = f"
by(simp add: bind_rpv_def)

(* Viewed as a reactive value, Done can return every value. *)
lemma results'_rpv_Done [simp]: "results'_rpv Done = UNIV"
by(auto simp add: results'_rpv_def)


subsection ‹ Embedding @{typ "'a spmf"} as a monad ›

(* Distribution of relation composition (OO) over the function relator, under
   the assumptions that R is left-unique and right-total and that S is
   right-unique and left-total.  It relates (R OO R' ===> S OO S') to
   ((R ===> S) OO (R' ===> S')); the comparison symbol between the two sides
   is not legible in this copy of the file (presumably an inequality, as
   suggested by its use together with fun_mono below; TODO confirm).
   Proof: the assumptions are turned into functional-relation facts, rel_fun
   and OO are unfolded, and the required intermediate function is obtained
   with the choice rule. *)
lemma neg_fun_distr3:
  includes lifting_syntax
  assumes 1: "left_unique R" "right_total R"
  assumes 2: "right_unique S" "left_total S"
  shows "(R OO R' ===> S OO S')  ((R ===> S) OO (R' ===> S'))"
using functional_relation[OF 2] functional_converse_relation[OF 1]
unfolding rel_fun_def OO_def
apply clarify
apply (subst all_comm)
apply (subst all_conj_distrib[symmetric])
apply (intro choice)
by metis

(* Locale without parameters or assumptions; it scopes the definitions and
   transfer rules for embedding spmf computations into GPVs, which become
   available where the locale is interpreted (as done in option_to_gpv and
   option_le_gpv below). *)
locale spmf_to_gpv begin

text ‹
  The lifting package cannot handle free term variables in the merging of transfer rules,
  so for the embedding we define a specialised relator ‹rel_gpv'›
  which acts only on the returned values.
›

(* rel_gpv' A is rel_gpv with the output relation fixed to equality, so only
   the relation A on results remains as a parameter. *)
definition rel_gpv' :: "('a  'b  bool)  ('a, 'out, 'in) gpv  ('b, 'out, 'in) gpv  bool"
where "rel_gpv' A = rel_gpv A (=)"

(* Basic relator properties of rel_gpv', registered with the corresponding
   attributes: it maps equality to equality ... *)
lemma rel_gpv'_eq [relator_eq]: "rel_gpv' (=) = (=)"
unfolding rel_gpv'_def gpv.rel_eq ..

(* ... it is monotone ... *)
lemma rel_gpv'_mono [relator_mono]: "A  B  rel_gpv' A  rel_gpv' B"
unfolding rel_gpv'_def by(rule gpv.rel_mono; simp)

(* ... and it distributes over relation composition. *)
lemma rel_gpv'_distr [relator_distr]: "rel_gpv' A OO rel_gpv' B = rel_gpv' (A OO B)"
unfolding rel_gpv'_def by (metis OO_eq gpv.rel_compp) 

(* rel_gpv' preserves uniqueness and totality properties of A.  Each fact is
   obtained from the corresponding rel_gpv fact together with the fact that
   equality has the property in question. *)
lemma left_unique_rel_gpv' [transfer_rule]: "left_unique A  left_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: left_unique_rel_gpv left_unique_eq)

lemma right_unique_rel_gpv' [transfer_rule]: "right_unique A  right_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: right_unique_rel_gpv right_unique_eq)

lemma bi_unique_rel_gpv' [transfer_rule]: "bi_unique A  bi_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: bi_unique_rel_gpv bi_unique_eq)

lemma left_total_rel_gpv' [transfer_rule]: "left_total A  left_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: left_total_rel_gpv left_total_eq)

lemma right_total_rel_gpv' [transfer_rule]: "right_total A  right_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: right_total_rel_gpv right_total_eq)

lemma bi_total_rel_gpv' [transfer_rule]: "bi_total A  bi_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: bi_total_rel_gpv bi_total_eq)


text ‹
  We cannot use ‹setup_lifting› because @{typ "('a, 'out, 'in) gpv"} contains
  type variables which do not appear in @{typ "'a spmf"}.
›

(* Correspondence relation between spmfs and GPVs: p corresponds to gpv when
   gpv is the lifting of p. *)
definition cr_spmf_gpv :: "'a spmf  ('a, 'out, 'in) gpv  bool"
where "cr_spmf_gpv p gpv  gpv = lift_spmf p"

(* Partial inverse of lift_spmf, defined with the definite description
   operator THE; the definition says nothing useful for GPVs that are not of
   the form lift_spmf p. *)
definition spmf_of_gpv :: "('a, 'out, 'in) gpv  'a spmf"
where "spmf_of_gpv gpv = (THE p. gpv = lift_spmf p)"

(* spmf_of_gpv undoes lift_spmf. *)
lemma spmf_of_gpv_lift_spmf [simp]: "spmf_of_gpv (lift_spmf p) = p"
unfolding spmf_of_gpv_def by auto

(* If p and q are related by rel_spmf A, every element y of the support of q
   has an A-related partner x in the support of p. *)
lemma rel_spmf_setD2:
  " rel_spmf A p q; y  set_spmf q   xset_spmf p. A x y"
by(erule rel_spmfE) force

(* Anything related to a lifted spmf on the left is itself a lifted spmf:
   rel_gpv A B (lift_spmf p) gpv is characterised by the existence of a q with
   gpv = lift_spmf q and rel_spmf A p q.
   In the direction that has to construct q, the witness is
   map_spmf result (the_gpv gpv); rel_spmf_setD2 is used to show that all
   steps of gpv are Pure steps, so that re-wrapping the results gives back
   the_gpv gpv. *)
lemma rel_gpv_lift_spmf1: "rel_gpv A B (lift_spmf p) gpv  (q. gpv = lift_spmf q  rel_spmf A p q)"
apply(subst gpv.rel_sel)
apply(simp add: spmf_rel_map rel_generat_Pure1)
apply safe
 apply(rule exI[where x="map_spmf result (the_gpv gpv)"])
 apply(clarsimp simp add: spmf_rel_map)
 apply(rule conjI)
  apply(rule gpv.expand)
  apply(simp add: spmf.map_comp)
  apply(subst map_spmf_cong[OF refl, where g=id])
   apply(drule (1) rel_spmf_setD2)
   apply clarsimp
  apply simp
 apply(erule rel_spmf_mono)
 apply clarsimp
apply(clarsimp simp add: spmf_rel_map)
done

(* Mirror image of rel_gpv_lift_spmf1 with the lifted spmf on the right;
   obtained from it by flipping the relations. *)
lemma rel_gpv_lift_spmf2: "rel_gpv A B gpv (lift_spmf q)  (p. gpv = lift_spmf p  rel_spmf A p q)"
by(subst gpv.rel_flip[symmetric])(simp add: rel_gpv_lift_spmf1 pmf.rel_flip option.rel_conversep)

(* Parametrised correspondence relation: cr_spmf_gpv composed with rel_gpv A
   on results and equality on outputs. *)
definition pcr_spmf_gpv :: "('a  'b  bool)  'a spmf  ('b, 'out, 'in) gpv  bool"
where "pcr_spmf_gpv A = cr_spmf_gpv OO rel_gpv A (=)"

(* For A being equality the parametrised relation collapses to
   cr_spmf_gpv. *)
lemma pcr_cr_eq_spmf_gpv: "pcr_spmf_gpv (=) = cr_spmf_gpv"
by(simp add: pcr_spmf_gpv_def gpv.rel_eq OO_eq)

(* Uniqueness and totality properties of cr_spmf_gpv and, derived from them
   by composition (the *_OO rules), of pcr_spmf_gpv A under the corresponding
   assumption on A.  Only the pcr versions are registered as transfer rules.
   Only left-totality is stated here; no right-totality fact appears. *)
lemma left_unique_cr_spmf_gpv: "left_unique cr_spmf_gpv"
by(rule left_uniqueI)(simp add: cr_spmf_gpv_def)

lemma left_unique_pcr_spmf_gpv [transfer_rule]:
  "left_unique A  left_unique (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro left_unique_OO left_unique_cr_spmf_gpv left_unique_rel_gpv left_unique_eq)

lemma right_unique_cr_spmf_gpv: "right_unique cr_spmf_gpv"
by(rule right_uniqueI)(simp add: cr_spmf_gpv_def)

lemma right_unique_pcr_spmf_gpv [transfer_rule]:
  "right_unique A  right_unique (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro right_unique_OO right_unique_cr_spmf_gpv right_unique_rel_gpv right_unique_eq)

lemma bi_unique_cr_spmf_gpv: "bi_unique cr_spmf_gpv"
by(simp add: bi_unique_alt_def left_unique_cr_spmf_gpv right_unique_cr_spmf_gpv)

lemma bi_unique_pcr_spmf_gpv [transfer_rule]: "bi_unique A  bi_unique (pcr_spmf_gpv A)"
by(simp add: bi_unique_alt_def left_unique_pcr_spmf_gpv right_unique_pcr_spmf_gpv)

lemma left_total_cr_spmf_gpv: "left_total cr_spmf_gpv"
by(rule left_totalI)(simp add: cr_spmf_gpv_def)

lemma left_total_pcr_spmf_gpv [transfer_rule]: "left_total A ==> left_total (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro left_total_OO left_total_cr_spmf_gpv left_total_rel_gpv left_total_eq)

(* Transfer rules relating the spmf operations to their GPV counterparts.
   Each operation gets a primed rule for the plain correspondence relation
   cr_spmf_gpv and an unprimed rule for pcr_spmf_gpv, which is the registered
   transfer rule. *)
context includes lifting_syntax begin

(* return_spmf corresponds to Done. *)
lemma return_spmf_gpv_transfer':
  "((=) ===> cr_spmf_gpv) return_spmf Done"
by(rule rel_funI)(simp add: cr_spmf_gpv_def)

(* Parametrised version.  Proof pattern: unfold pcr_spmf_gpv_def, rewrite the
   argument relation into a composition with equality (eq_OO), distribute the
   composition over the function relator (pos_fun_distr), and split it with
   relcomppI into the primed rule and a parametricity goal that
   transfer_prover solves. *)
lemma return_spmf_gpv_transfer [transfer_rule]:
  "(A ===> pcr_spmf_gpv A) return_spmf Done"
unfolding pcr_spmf_gpv_def
apply(rewrite in "( ===> _) _ _" eq_OO[symmetric])
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(rule return_spmf_gpv_transfer')
apply transfer_prover
done

(* bind_spmf corresponds to bind_gpv with respect to cr_spmf_gpv. *)
lemma bind_spmf_gpv_transfer':
  "(cr_spmf_gpv ===> ((=) ===> cr_spmf_gpv) ===> cr_spmf_gpv) bind_spmf bind_gpv"
apply(clarsimp simp add: rel_fun_def cr_spmf_gpv_def)
apply(rule gpv.expand)
apply(simp add: bind_map_spmf map_spmf_bind_spmf o_def)
done

(* Parametrised version.  Because the continuation is itself a function, the
   composition has to be distributed over a function relator that occurs in
   argument position; this is where neg_fun_distr3 is needed (instantiated
   with the uniqueness/totality facts for equality and cr_spmf_gpv), combined
   with fun_mono and pos_fun_distr.  The goal is finally split with relcomppI
   into the primed rule and a parametricity goal for transfer_prover. *)
lemma bind_spmf_gpv_transfer [transfer_rule]:
  "(pcr_spmf_gpv A ===> (A ===> pcr_spmf_gpv B) ===> pcr_spmf_gpv B) bind_spmf bind_gpv"
unfolding pcr_spmf_gpv_def
apply(rewrite in "(_ ===> ( ===> _) ===> _) _ _" eq_OO[symmetric])
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule order.refl)
 apply(rule fun_mono)
  apply(rule neg_fun_distr3[OF left_unique_eq right_total_eq right_unique_cr_spmf_gpv left_total_cr_spmf_gpv])
 apply(rule order.refl)
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule order.refl)
 apply(rule pos_fun_distr)
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(rule bind_spmf_gpv_transfer')
apply transfer_prover
done

(* On the spmf side lift_spmf corresponds to the identity function. *)
lemma lift_spmf_gpv_transfer':
  "((=) ===> cr_spmf_gpv) (λx. x) lift_spmf"
by(simp add: rel_fun_def cr_spmf_gpv_def)

(* Parametrised version, proved with the same pattern as
   return_spmf_gpv_transfer. *)
lemma lift_spmf_gpv_transfer [transfer_rule]:
  "(rel_spmf A ===> pcr_spmf_gpv A) (λx. x) lift_spmf"
unfolding pcr_spmf_gpv_def
apply(rewrite in "( ===> _) _ _" eq_OO[symmetric])
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(rule lift_spmf_gpv_transfer')
apply transfer_prover
done

(* The failing spmf return_pmf None corresponds to Fail. *)
lemma fail_spmf_gpv_transfer': "cr_spmf_gpv (return_pmf None) Fail"
by(simp add: cr_spmf_gpv_def)

(* Parametrised version: split the composition with relcomppI into the primed
   rule and a parametricity goal for transfer_prover. *)
lemma fail_spmf_gpv_transfer [transfer_rule]: "pcr_spmf_gpv A (return_pmf None) Fail"
unfolding pcr_spmf_gpv_def
apply(rule relcomppI)
 apply(rule fail_spmf_gpv_transfer')
apply transfer_prover
done

(* map_gpv corresponds to map_spmf on the result function.  The spmf side
   takes and ignores a second function argument g, so that the arities match;
   its relation R is arbitrary. *)
lemma map_spmf_gpv_transfer':
  "((=) ===> R ===> cr_spmf_gpv ===> cr_spmf_gpv) (λf g. map_spmf f) map_gpv"
by(simp add: rel_fun_def cr_spmf_gpv_def map_lift_spmf)

(* Parametrised version.  The relations A, B and R are first rewritten into
   compositions with equality (eq_OO / OO_eq), the compositions are
   distributed over the function relators with neg_fun_distr3 (argument
   position) and pos_fun_distr (result position), and the goal is split with
   relcomppI.  rel_fun_eq is unfolded to apply the primed rule and folded
   again before transfer_prover is called. *)
lemma map_spmf_gpv_transfer [transfer_rule]:
  "((A ===> B) ===> R ===> pcr_spmf_gpv A ===> pcr_spmf_gpv B) (λf g. map_spmf f) map_gpv"
unfolding pcr_spmf_gpv_def
apply(rewrite in "(( ===> _) ===> _)  _ _" eq_OO[symmetric])
apply(rewrite in "((_ ===> ) ===> _)  _ _" eq_OO[symmetric])
apply(rewrite in "(_ ===>  ===> _)  _ _" OO_eq[symmetric])
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule neg_fun_distr3[OF left_unique_eq right_total_eq right_unique_eq left_total_eq])
 apply(rule fun_mono[OF order.refl])
 apply(rule pos_fun_distr)
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule order.refl)
 apply(rule pos_fun_distr)
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(unfold rel_fun_eq)
 apply(rule map_spmf_gpv_transfer')
apply(unfold rel_fun_eq[symmetric])
apply transfer_prover
done

end

end

subsection ‹ Embedding @{typ "'a option"} as a monad ›

(* Locale for embedding the option monad into GPVs.  It interprets
   option_to_spmf and spmf_to_gpv so that their definitions (e.g.
   cr_spmf_option, cr_spmf_gpv) are in scope. *)
locale option_to_gpv begin

interpretation option_to_spmf .
interpretation spmf_to_gpv .

(* An option value x corresponds to gpv when gpv is the lifting of the point
   distribution on x. *)
definition cr_option_gpv :: "'a option  ('a, 'out, 'in) gpv  bool"
where "cr_option_gpv x gpv  gpv = (lift_spmf  return_pmf) x"

(* cr_option_gpv factors through spmfs: it is the composition of the converse
   of cr_spmf_option with cr_spmf_gpv. *)
lemma cr_option_gpv_conv_OO:
  "cr_option_gpv = cr_spmf_option¯¯ OO cr_spmf_gpv"
by(simp add: fun_eq_iff relcompp.simps cr_option_gpv_def cr_spmf_gpv_def cr_spmf_option_def)

context includes lifting_syntax begin

text ‹These transfer rules should follow from merging the transfer rules, but this has not yet been implemented.›

(* Some corresponds to Done. *)
lemma return_option_gpv_transfer [transfer_rule]:
  "((=) ===> cr_option_gpv) Some Done"
by(simp add: cr_option_gpv_def rel_fun_def)

(* Option.bind corresponds to bind_gpv; after unfolding, the proof is a case
   distinction on the option value x. *)
lemma bind_option_gpv_transfer [transfer_rule]:
  "(cr_option_gpv ===> ((=) ===> cr_option_gpv) ===> cr_option_gpv) Option.bind bind_gpv"
apply(clarsimp simp add: cr_option_gpv_def rel_fun_def)
subgoal for x f g by(cases x; simp)
done

(* None corresponds to Fail. *)
lemma fail_option_gpv_transfer [transfer_rule]: "cr_option_gpv None Fail"
by(simp add: cr_option_gpv_def)

(* map_option corresponds to map_gpv; as in map_spmf_gpv_transfer', the option
   side ignores the second function argument, whose relation R is
   arbitrary. *)
lemma map_option_gpv_transfer [transfer_rule]:
  "((=) ===> R ===> cr_option_gpv ===> cr_option_gpv) (λf g. map_option f) map_gpv"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_option_gpv_def map_lift_spmf)

end

end

(* Variant of option_to_gpv with a weaker correspondence relation.  The lemma
   names bind_option_gpv_transfer, fail_option_gpv_transfer and
   map_option_gpv_transfer are reused from option_to_gpv; they live in this
   locale's own name space. *)
locale option_le_gpv begin

interpretation option_le_spmf .
interpretation spmf_to_gpv .

(* Relates x to gpv when gpv is the embedding of x, with a second alternative
   for x = None.  NOTE(review): the connective joining the two parts is not
   legible in this copy of the file; presumably a disjunction, so that None is
   related to every GPV -- confirm against the original. *)
definition cr_option_le_gpv :: "'a option  ('a, 'out, 'in) gpv  bool"
where "cr_option_le_gpv x gpv  gpv = (lift_spmf  return_pmf) x  x = None"

context includes lifting_syntax begin

(* Some corresponds to Done. *)
lemma return_option_le_gpv_transfer [transfer_rule]:
  "((=) ===> cr_option_le_gpv) Some Done"
by(simp add: cr_option_le_gpv_def rel_fun_def)

(* Option.bind corresponds to bind_gpv; the remaining subgoal is closed by
   instantiating the assumption about the continuation at the bound value
   y. *)
lemma bind_option_gpv_transfer [transfer_rule]:
  "(cr_option_le_gpv ===> ((=) ===> cr_option_le_gpv) ===> cr_option_le_gpv) Option.bind bind_gpv"
apply(clarsimp simp add: cr_option_le_gpv_def rel_fun_def bind_eq_Some_conv)
subgoal for f g x y by(erule allE[where x=y]) auto
done

(* None corresponds to Fail. *)
lemma fail_option_gpv_transfer [transfer_rule]:
  "cr_option_le_gpv None Fail"
by(simp add: cr_option_le_gpv_def)

(* Unlike the rule in option_to_gpv, this one relates map_option to map_gpv
   with the output function fixed to id, so both sides take one function
   argument. *)
lemma map_option_gpv_transfer [transfer_rule]:
  "(((=) ===> (=)) ===> cr_option_le_gpv ===> cr_option_le_gpv) map_option (λf. map_gpv f id)"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_option_le_gpv_def map_lift_spmf)

end

end

subsection ‹Embedding resumptions›

(* Embeds a resumption into GPVs by primitive corecursion on r:
   - resumption.Done None becomes the failing unfolding return_pmf None,
   - resumption.Done (Some x') becomes a Pure x' step,
   - resumption.Pause out c becomes an IO step with output out whose
     continuation is c followed by lift_resumption. *)
primcorec lift_resumption :: "('a, 'out, 'in) resumption  ('a, 'out, 'in) gpv"
where
  "the_gpv (lift_resumption r) = 
  (case r of resumption.Done None  return_pmf None
    | resumption.Done (Some x') => return_spmf (Pure x')
    | resumption.Pause out c => map_spmf (map_generat id id ((∘) lift_resumption)) (return_spmf (IO out c)))"

(* Unfolding of lift_resumption written with the discriminator is_Done and the
   selectors of resumptions instead of a case expression. *)
lemma the_gpv_lift_resumption:
  "the_gpv (lift_resumption r) = 
   (if is_Done r then if Option.is_none (resumption.result r) then return_pmf None else return_spmf (Pure (the (resumption.result r)))
    else return_spmf (IO (resumption.output r) (lift_resumption  resume r)))"
by(simp split: option.split resumption.split)

(* The generated equations are removed from the simp set; the constructor
   lemmas that follow are used instead. *)
declare lift_resumption.simps [simp del]

(* Code equation for a finished resumption: None maps to Fail, Some x' to
   Done x'. *)
lemma lift_resumption_Done [code]:
  "lift_resumption (resumption.Done x) = (case x of None  Fail | Some x'  Done x')"
by(rule gpv.expand)(simp add: the_gpv_lift_resumption split: option.split)

(* The abbreviating constants DONE and ABORT map to Done and Fail. *)
lemma lift_resumption_DONE [simp]:
  "lift_resumption (DONE x) = Done x"
by(simp add: DONE_def lift_resumption_Done)

lemma lift_resumption_ABORT [simp]:
  "lift_resumption ABORT = Fail"
by(simp add: ABORT_def lift_resumption_Done)

(* A paused resumption maps to Pause with the lifted continuation; simp rule
   and code equation. *)
lemma lift_resumption_Pause [simp, code]:
  "lift_resumption (resumption.Pause out c) = Pause out (lift_resumption  c)"
by(rule gpv.expand)(simp add: the_gpv_lift_resumption)

(* lift_resumption_DONE with DONE unfolded. *)
lemma lift_resumption_Done_Some [simp]: "lift_resumption (resumption.Done (Some x)) = Done x"
using lift_resumption_DONE unfolding DONE_def by simp

(* Lifting preserves the set of results.  One inclusion is shown by induction
   over membership in results'_gpv (with the GPV generalised to an arbitrary
   lifted resumption), the other by induction over membership in
   results r. *)
lemma results'_gpv_lift_resumption [simp]:
  "results'_gpv (lift_resumption r) = results r" (is "?lhs = ?rhs")
proof(rule set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv"lift_resumption r" arbitrary: r)
      (auto intro: resumption.set_sel simp add: lift_resumption.sel split: resumption.split_asm option.split_asm)
  show "x  ?lhs" if "x  ?rhs" for x using that by induction(auto simp add: lift_resumption.sel)
qed

(* Lifting preserves the set of outputs; same proof structure. *)
lemma outs'_gpv_lift_resumption [simp]:
  "outs'_gpv (lift_resumption r) = outputs r" (is "?lhs = ?rhs")
proof(rule set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv"lift_resumption r" arbitrary: r)
      (auto simp add: lift_resumption.sel split: resumption.split_asm option.split_asm)
  show "x  ?lhs" if "x  ?rhs" for x using that by induction auto
qed

(* pred_gpv on a lifted resumption coincides with pred_resumption; follows
   from the two preceding set equations after unfolding both predicates. *)
lemma pred_gpv_lift_resumption [simp]:
  "A. pred_gpv A C (lift_resumption r) = pred_resumption A C r"
by(simp add: pred_gpv_def pred_resumption_def)

(* lift_resumption commutes with bind: lifting a resumption-level bind equals
   the GPV-level bind of the lifted parts.  Proved by strong coinduction. *)
lemma lift_resumption_bind: "lift_resumption (r  f) = lift_resumption r  lift_resumption  f"
by(coinduction arbitrary: r rule: gpv.coinduct_strong)
  (auto simp add: lift_resumption.sel Done_bind split: resumption.split option.split del: rel_funI intro!: rel_funI)

subsection ‹Assertions›

(* Assertion as a GPV: Done () when b holds, Fail otherwise. *)
definition assert_gpv :: "bool  (unit, 'out, 'in) gpv"
where "assert_gpv b = (if b then Done () else Fail)"

lemma assert_gpv_simps [simp]:
  "assert_gpv True = Done ()"
  "assert_gpv False = Fail"
by(simp_all add: assert_gpv_def)

(* Comparisons of assert_gpv b with Done, Pause and Fail, each in both
   orientations: it is a Done exactly when b holds, a Fail exactly when b does
   not hold, and never a Pause. *)
lemma [simp]:
  shows assert_gpv_eq_Done: "assert_gpv b = Done x  b"
  and Done_eq_assert_gpv: "Done x = assert_gpv b  b"
  and Pause_neq_assert_gpv: "Pause out rpv  assert_gpv b"
  and assert_gpv_neq_Pause: "assert_gpv b  Pause out rpv"
  and assert_gpv_eq_Fail: "assert_gpv b = Fail  ¬ b"
  and Fail_eq_assert_gpv: "Fail = assert_gpv b  ¬ b"
by(simp_all add: assert_gpv_def)

(* assert_gpv is injective. *)
lemma assert_gpv_inject [simp]: "assert_gpv b = assert_gpv b'  b = b'"
by(simp add: assert_gpv_def)

(* Unfolding of assert_gpv in terms of the spmf-level assert_spmf. *)
lemma assert_gpv_sel [simp]:
  "the_gpv (assert_gpv b) = map_spmf Pure (assert_spmf b)"
by(simp add: assert_gpv_def)

(* Unfolding of a bind that starts with an assertion. *)
lemma the_gpv_bind_assert [simp]:
  "the_gpv (bind_gpv (assert_gpv b) f) =
   bind_spmf (assert_spmf b) (the_gpv  f)"
by(cases b) simp_all

(* pred_gpv on an assertion: P () is only required when b holds. *)
lemma pred_gpv_assert [simp]: "pred_gpv P Q (assert_gpv b) = (b  P ())"
by(cases b) simp_all

(* Failure handling for GPVs with the mixfix syntax TRY _ ELSE _.
   The first step is try_spmf applied to the unfoldings of gpv and gpv'.
   Continuations stemming from gpv are tagged with Inl and remain guarded by
   the handler (try_gpv is applied to them corecursively), whereas
   continuations stemming from gpv' are tagged with Inr and are returned
   unchanged. *)
primcorec try_gpv :: "('a, 'call, 'ret) gpv  ('a, 'call, 'ret) gpv  ('a, 'call, 'ret) gpv" ("TRY _ ELSE _" [0,60] 59)
where
  "the_gpv (TRY gpv ELSE gpv') = 
   map_spmf (map_generat id id (λc input. case c input of Inl gpv  try_gpv gpv gpv' | Inr gpv'  gpv'))
     (try_spmf (map_spmf (map_generat id id (map_fun id Inl)) (the_gpv gpv))
               (map_spmf (map_generat id id (map_fun id Inr)) (the_gpv gpv')))"

(* Unfolding without the Inl/Inr encoding: an spmf-level TRY whose first
   branch is the unfolding of gpv with the handler pushed into every
   continuation and whose second branch is the unfolding of gpv'. *)
lemma try_gpv_sel:
  "the_gpv (TRY gpv ELSE gpv') =
   TRY map_spmf (map_generat id id (λc input. TRY c input ELSE gpv')) (the_gpv gpv) ELSE the_gpv gpv'"
by(simp add: try_gpv_def map_try_spmf spmf.map_comp o_def generat.map_comp generat.map_ident id_def)

(* A Done value never triggers the handler. *)
lemma try_gpv_Done [simp]: "TRY Done x ELSE gpv' = Done x"
by(rule gpv.expand)(simp)

(* Fail in the first position triggers the handler immediately. *)
lemma try_gpv_Fail [simp]: "TRY Fail ELSE gpv' = gpv'"
by(rule gpv.expand)(simp add: spmf.map_comp o_def generat.map_comp generat.map_ident)

(* The handler moves under Pause into the continuation. *)
lemma try_gpv_Pause [simp]: "TRY Pause out c ELSE gpv' = Pause out (λinput. TRY c input ELSE gpv')"
by(rule gpv.expand) simp

(* Fail as the handler changes nothing; proved by strong coinduction. *)
lemma try_gpv_Fail2 [simp]: "TRY gpv ELSE Fail = gpv"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
  (auto simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI generat.rel_refl)

(* lift_spmf commutes with TRY/ELSE. *)
lemma lift_try_spmf: "lift_spmf (TRY p ELSE q) = TRY lift_spmf p ELSE lift_spmf q" 
by(rule gpv.expand)(simp add: map_try_spmf spmf.map_comp o_def)

(* Guarding an assertion: Done () when b holds, the handler otherwise. *)
lemma try_assert_gpv: "TRY assert_gpv b ELSE gpv' = (if b then Done () else gpv')"
by(simp)

(* Parametricity of try_gpv for rel_gpv (registered as transfer rule) and for
   rel_gpv'' (not registered).  Both unfold try_gpv_def and use
   transfer_prover; the rel_gpv'' version supplies the primed parametricity
   rules for corec_gpv and the_gpv locally. *)
context includes lifting_syntax begin
lemma try_gpv_parametric [transfer_rule]:
  "(rel_gpv A C ===> rel_gpv A C ===> rel_gpv A C) try_gpv try_gpv"
unfolding try_gpv_def by transfer_prover

lemma try_gpv_parametric':
  "(rel_gpv'' A C R ===> rel_gpv'' A C R ===> rel_gpv'' A C R) try_gpv try_gpv"
unfolding try_gpv_def
supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
by transfer_prover
end

(* map_gpv distributes over TRY/ELSE.  Derived from parametricity of try_gpv:
   equality is expressed via gpv.rel_eq and the map is moved into the relation
   with gpv.rel_map. *)
lemma map_try_gpv: "map_gpv f g (TRY gpv ELSE gpv') = TRY map_gpv f g gpv ELSE map_gpv f g gpv'"
by(simp add: gpv.rel_map try_gpv_parametric[THEN rel_funD, THEN rel_funD] gpv.rel_refl gpv.rel_eq[symmetric])

(* The same for map_gpv', proved directly by strong coinduction. *)
lemma map'_try_gpv: "map_gpv' f g h (TRY gpv ELSE gpv') = TRY map_gpv' f g h gpv ELSE map_gpv' f g h gpv'"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)(auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI generat.rel_refl rel_funI rel_spmf_try_spmf)
  

(* A guarded computation that starts with an assertion: if b holds the
   assertion disappears and f () stays guarded, otherwise the handler gpv is
   the result. *)
lemma try_bind_assert_gpv:
  "TRY (assert_gpv b  f) ELSE gpv = (if b then TRY (f ()) ELSE gpv else gpv)"
by(simp)



subsection ‹Order for @{typ "('a, 'out, 'in) gpv"}›

(* Coinductively defined order on GPVs: GPV f is below GPV g when f and g are
   related by ord_spmf, where individual steps are compared with rel_generat
   using equality on results and outputs and, for continuations, ord_gpv
   pointwise at equal inputs. *)
coinductive ord_gpv :: "('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv  bool"
where
  "ord_spmf (rel_generat (=) (=) (rel_fun (=) ord_gpv)) f g  ord_gpv (GPV f) (GPV g)"

(* Simp rule unfolding ord_gpv on two explicit GPV constructor
   applications. *)
inductive_simps ord_gpv_simps [simp]:
  "ord_gpv (GPV f) (GPV g)"

(* Coinduction rule for ord_gpv phrased with the selector the_gpv instead of the
   constructor GPV: to show ord_gpv f g, exhibit a relation X containing (f, g)
   whose members satisfy the ord_spmf step with X on the continuations.
   Installed as the default coinduct rule for the predicate ord_gpv. *)
lemma ord_gpv_coinduct [consumes 1, case_names ord_gpv, coinduct pred: ord_gpv]:
  assumes "X f g"
  and step: "f g. X f g  ord_spmf (rel_generat (=) (=) (rel_fun (=) X)) (the_gpv f) (the_gpv g)"
  shows "ord_gpv f g"
using X f g
by(coinduct)(auto dest: step simp add: eq_GPV_iff intro: ord_spmf_mono rel_generat_mono rel_fun_mono)

(* Destruction rule: ord_gpv f g unfolds to the ord_spmf step on the_gpv f and the_gpv g. *)
lemma ord_gpv_the_gpvD:
  "ord_gpv f g  ord_spmf (rel_generat (=) (=) (rel_fun (=) ord_gpv)) (the_gpv f) (the_gpv g)"
by(erule ord_gpv.cases) simp

(* Equality is reflexive, stated with reflp; used as a side condition in ord_gpv_antisym. *)
lemma reflp_equality: "reflp (=)"
by(simp add: reflp_def)

(* ord_gpv is reflexive; proved by coinduction and declared as a simp rule. *)
lemma ord_gpv_reflI [simp]: "ord_gpv f f"
by(coinduction arbitrary: f)(auto intro: ord_spmf_reflI simp add: rel_generat_same rel_fun_def)

(* Reflexivity of ord_gpv packaged with the predicate reflp. *)
lemma reflp_ord_gpv: "reflp ord_gpv"
by(rule reflpI)(rule ord_gpv_reflI)

(* ord_gpv is transitive. Proved by coinduction with the intermediate element g
   generalised. *)
lemma ord_gpv_trans:
  assumes "ord_gpv f g" "ord_gpv g h"
  shows "ord_gpv f h"
using assms
proof(coinduction arbitrary: f g h)
  case (ord_gpv f g h)
  (* Rewrite the coinduction goal into relational-composition (OO) form so that
     the composition lemmas for generat, ord_spmf and functions apply below. *)
  have *: "ord_spmf (rel_generat (=) (=) (rel_fun (=) (λf h. g. ord_gpv f g  ord_gpv g h))) (the_gpv f) (the_gpv h) =
    ord_spmf (rel_generat ((=) OO (=)) ((=) OO (=)) (rel_fun (=) (ord_gpv OO ord_gpv))) (the_gpv f) (the_gpv h)"
    by(simp add: relcompp.simps[abs_def])
  then show ?case using ord_gpv
    by(auto elim!: ord_gpv.cases simp add: generat.rel_compp ord_spmf_compp fun.rel_compp)
qed

(* ord_gpv composed with itself is ord_gpv again; the proof uses transitivity
   (given as an intro rule), the other direction is left to auto. *)
lemma ord_gpv_compp: "(ord_gpv OO ord_gpv) = ord_gpv"
by(auto simp add: fun_eq_iff intro: ord_gpv_trans)

(* Transitivity of ord_gpv packaged with the predicate transp; simp rule. *)
lemma transp_ord_gpv [simp]: "transp ord_gpv"
by(blast intro: transpI ord_gpv_trans)

(* ord_gpv is antisymmetric: gpvs related in both directions are equal.
   Proved by coinduction on gpv equality. *)
lemma ord_gpv_antisym:
  " ord_gpv f g; ord_gpv g f   f = g"
proof(coinduction arbitrary: f g)
  case (Eq_gpv f g)
  let ?R = "rel_generat (=) (=) (rel_fun (=) ord_gpv)"
  (* Unfold both hypotheses one step. *)
  from ‹ord_gpv f g have "ord_spmf ?R (the_gpv f) (the_gpv g)" by cases simp
  moreover
  from ‹ord_gpv g f have "ord_spmf ?R (the_gpv g) (the_gpv f)" by cases simp
  (* Combine both directions into rel_spmf over the infimum of ?R and its converse. *)
  ultimately have "rel_spmf (inf ?R ?R¯¯) (the_gpv f) (the_gpv g)"
    by(rule rel_spmf_inf)(auto 4 3 intro: transp_rel_generatI transp_ord_gpv reflp_ord_gpv reflp_equality reflp_fun1 is_equality_eq transp_rel_fun)
  (* Push the infimum inside rel_generat and rel_fun to match the coinduction goal. *)
  also have "inf ?R ?R¯¯ = rel_generat (inf (=) (=)) (inf (=) (=)) (rel_fun (=) (inf ord_gpv ord_gpv¯¯))"
    unfolding rel_generat_inf[symmetric] rel_fun_inf[symmetric]
    by(simp add: generat.rel_conversep[symmetric] fun.rel_conversep)
  finally show ?case by(simp add: inf_fun_def)
qed

(* Fail is the least element of ord_gpv: it is below every gpv. *)
lemma RFail_least [simp]: "ord_gpv Fail f"
by(coinduction arbitrary: f)(simp add: eq_GPV_iff)

subsection ‹Bounds on interaction›

context
  fixes "consider" :: "'out  bool"
  notes monotone_SUP[partial_function_mono] [[function_internals]]
begin
(* Registers a partial_function mode named "lfp_strong", built from lfp.fixp_fun and
   lfp.mono_body with fixp_rule_uc and fixp_induct_strong2_uc as its rules.
   It is used by the partial_function definition of interaction_bound that follows. *)
declaration Partial_Function.init "lfp_strong" @{term lfp.fixp_fun} @{term lfp.mono_body}
  @{thm lfp.fixp_rule_uc} @{thm lfp.fixp_induct_strong2_uc} NONE›

(* interaction_bound gpv, an enat, is the supremum over all generats in the support
   of the_gpv gpv of:
     - 0 for a Pure result;
     - for IO out c, the supremum over all inputs of the bound of c input,
       incremented by eSuc exactly when the context predicate `consider` holds for out.
   Defined with the lfp_strong mode of partial_function. *)
partial_function (lfp_strong) interaction_bound :: "('a, 'out, 'in) gpv  enat"
where
  "interaction_bound gpv =
  (SUP generatset_spmf (the_gpv gpv). case generat of Pure _  0 
     | IO out c  if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input)))"

(* Fixpoint induction rule for interaction_bound with named cases adm / bottom / step.
   adm: P is admissible; bottom: P holds for the constant-0 function;
   step: P is preserved by one unfolding of the defining functional, where the step
   case may additionally assume that the approximation interaction_bound' is below
   interaction_bound and below its own unfolding. *)
lemma interaction_bound_fixp_induct [case_names adm bottom step]:
  " ccpo.admissible (fun_lub Sup) (fun_ord (≤)) P;
     P (λ_. 0);
    interaction_bound'. 
     P interaction_bound'; 
      gpv. interaction_bound' gpv  interaction_bound gpv;
      gpv. interaction_bound' gpv  (SUP generatset_spmf (the_gpv gpv). case generat of Pure _  0 
     | IO out c  if consider out then eSuc (SUP input. interaction_bound' (c input)) else (SUP input. interaction_bound' (c input)))
      
       P (λgpv. generatset_spmf (the_gpv gpv). case generat of Pure x  0
         | IO out c  if consider out then eSuc (input. interaction_bound' (c input)) else (input. interaction_bound' (c input))) 
    P interaction_bound"
by(erule interaction_bound.fixp_induct)(simp_all add: bot_enat_def fun_ord_def)

(* For an IO out c step in the support of gpv, the bound of any continuation c input
   (plus one eSuc if out is considered) is at most the bound of gpv.
   The proof unfolds interaction_bound once on the right-hand side only. *)
lemma interaction_bound_IO:
   "IO out c  set_spmf (the_gpv gpv)
    (if consider out then eSuc (interaction_bound (c input)) else interaction_bound (c input))  interaction_bound gpv"
by(rewrite in "_  " interaction_bound.simps)(auto intro!: SUP_upper2)

(* Special case of interaction_bound_IO for a considered output: the successor of the
   continuation's bound is bounded by the bound of gpv. *)
lemma interaction_bound_IO_consider:
   " IO out c  set_spmf (the_gpv gpv); consider out 
    eSuc (interaction_bound (c input))  interaction_bound gpv"
by(drule interaction_bound_IO) simp

(* Special case of interaction_bound_IO for an output that is not considered: the
   continuation's bound itself is bounded by the bound of gpv. *)
lemma interaction_bound_IO_ignore:
   " IO out c  set_spmf (the_gpv gpv); ¬ consider out 
    interaction_bound (c input)  interaction_bound gpv"
by(drule interaction_bound_IO) simp

(* Done x has interaction bound 0. *)
lemma interaction_bound_Done [simp]: "interaction_bound (Done x) = 0"
by(simp add: interaction_bound.simps)

(* Fail has interaction bound 0; bot_enat_def identifies the bottom of enat with 0. *)
lemma interaction_bound_Fail [simp]: "interaction_bound Fail = 0"
by(simp add: interaction_bound.simps bot_enat_def)

(* Bound of Pause out c: the supremum of the continuations' bounds over all inputs,
   incremented by eSuc when out is considered. *)
lemma interaction_bound_Pause [simp]:
  "interaction_bound (Pause out c) = 
   (if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input)))"
by(simp add: interaction_bound.simps)

(* A lifted spmf has interaction bound 0. *)
lemma interaction_bound_lift_spmf [simp]: "interaction_bound (lift_spmf p) = 0"
by(simp add: interaction_bound.simps SUP_constant bot_enat_def)

(* An assertion has interaction bound 0, by case distinction on b. *)
lemma interaction_bound_assert_gpv [simp]: "interaction_bound (assert_gpv b) = 0"
by(cases b) simp_all

(* Induction step used by interaction_bound_bind.
   IH is the bind estimate for the approximation interaction_bound'; `unfold` says the
   approximation is below its own one-step unfolding. The conclusion is the bind
   estimate for the one-step unfolding. The proof takes an arbitrary generat' of the
   bind and distinguishes whether it stems from a Pure or an IO generat of p. *)
lemma interaction_bound_bind_step:
  assumes IH: "p. interaction_bound' (p  f)  interaction_bound p + (xresults'_gpv p. interaction_bound' (f x))"
  and unfold: "gpv. interaction_bound' gpv  (generatset_spmf (the_gpv gpv). case generat of Pure x  0
             | IO out c  if consider out then eSuc (input. interaction_bound' (c input)) else input. interaction_bound' (c input))"
  shows "(generatset_spmf (the_gpv (p  f)).
             case generat of Pure x  0
             | IO out c 
                 if consider out then eSuc (input. interaction_bound' (c input))
                 else input. interaction_bound' (c input))
          interaction_bound p +
            (xresults'_gpv p.
                generatset_spmf (the_gpv (f x)).
                   case generat of Pure x  0
                   | IO out c 
                       if consider out then eSuc (input. interaction_bound' (c input))
                       else input. interaction_bound' (c input))"
    (is "(SUP generat'?bind. ?g generat')  ?p + ?f")
proof(rule SUP_least)
  fix generat'
  assume "generat'  ?bind"
  (* Trace generat' back to the generat of p it was produced from. *)
  then obtain generat where generat: "generat  set_spmf (the_gpv p)"
    and *: "case generat of Pure x  generat'  set_spmf (the_gpv (f x)) 
         | IO out c  generat' = IO out (λinput. c input  f)"
    by(clarsimp simp add: bind_gpv.sel simp del: bind_gpv_sel')
      (clarsimp split: generat.split_asm simp add: generat.map_comp o_def generat.map_id[unfolded id_def])
  show "?g generat'  ?p + ?f"
  proof(cases generat)
    (* p returned x, so generat' is a generat of f x and is covered by the ?f summand. *)
    case (Pure x)
    have "?g generat'  (SUP generat'set_spmf (the_gpv (f x)). (case generat' of Pure x  0 | IO out c  if consider out then eSuc (input. interaction_bound' (c input)) else input. interaction_bound' (c input)))"
      using * Pure by(auto intro: SUP_upper)
    also have "  0 + ?f" using generat Pure
      by(auto 4 3 intro: SUP_upper results'_gpv_Pure)
    also have "  ?p + ?f" by simp
    finally show ?thesis .
  next
    (* p made an interaction; apply IH to each continuation and regroup the sums. *)
    case (IO out c)
    with * have "?g generat' = (if consider out then eSuc (SUP input. interaction_bound' (c input  f)) else (SUP input. interaction_bound' (c input  f)))" by simp
    also have "  (if consider out then eSuc (SUP input. interaction_bound (c input) + (xresults'_gpv (c input). interaction_bound' (f x))) else (SUP input. interaction_bound (c input) + (xresults'_gpv (c input). interaction_bound' (f x))))"
      by(auto intro: SUP_mono IH)
    also have "  (case IO out c of Pure (x :: 'a)  0 | IO out c  if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input))) + (SUP input. SUP xresults'_gpv (c input). interaction_bound' (f x))"
      by(simp add: iadd_Suc SUP_le_iff)(meson SUP_upper2 UNIV_I add_mono order_refl)
    also have "  ?p + ?f"
      apply(rewrite in "_  " interaction_bound.simps)
      apply(rule add_mono SUP_least SUP_upper generat[unfolded IO])+
      apply(rule order_trans[OF unfold])
      apply(auto 4 3 intro: results'_gpv_Cont[OF generat] SUP_upper simp add: IO)
      done
    finally show ?thesis .
  qed
qed

(* Upper bound for bind: the bound of the bind is at most the bound of p plus the
   supremum of the bounds of f x over the results x of p.
   The occurrence of interaction_bound for p is hidden behind the local definition
   ib1, so the fixpoint induction only touches the other occurrences; ib1 is
   unfolded again in the step case. *)
lemma interaction_bound_bind:
  defines "ib1  interaction_bound"
  shows "interaction_bound (p  f)  ib1 p + (SUP xresults'_gpv p. interaction_bound (f x))"
proof(induction arbitrary: p rule: interaction_bound_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step interaction_bound') then show ?case unfolding ib1_def by -(rule interaction_bound_bind_step)
qed

(* Exact equation when the first computation is a lifted spmf: the bound of the bind
   is the supremum of the bounds of f x over the support of p. *)
lemma interaction_bound_bind_lift_spmf [simp]:
  "interaction_bound (lift_spmf p  f) = (SUP xset_spmf p. interaction_bound (f x))"
by(subst (1 2) interaction_bound.simps)(simp add: bind_UNION SUP_UNION)

end

(* If h is surjective, the bound of map_gpv' f g h gpv w.r.t. `consider` equals the
   bound of gpv w.r.t. `consider` composed with g.
   Proved by parallel fixpoint induction on both occurrences of interaction_bound. *)
lemma interaction_bound_map_gpv':
  assumes "surj h"
  shows "interaction_bound consider (map_gpv' f g h gpv) = interaction_bound (consider  g) gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF lattice_partial_function_definition lattice_partial_function_definition interaction_bound.mono interaction_bound.mono interaction_bound_def interaction_bound_def, case_names adm bottom step])
  case (step interaction_bound' interaction_bound'' gpv)
  (* Surjectivity of h: every input x is h of something, so the supremum over
     inputs of the form h x already covers it. *)
  have *: "IO out c  set_spmf (the_gpv gpv)   x  UNIV  interaction_bound'' (c x)  (x. interaction_bound'' (c (h x)))" for out c x
    using assms[THEN surjD, of x] by (clarsimp intro!: SUP_upper)

  show ?case 
    by (auto simp add: * step.IH image_comp split: generat.split
      intro!: SUP_cong [OF refl] antisym SUP_upper SUP_least)
qed simp_all

(* interaction_bound where every output is considered (the predicate is constantly True). *)
abbreviation interaction_any_bound :: "('a, 'out, 'in) gpv  enat"
where "interaction_any_bound  interaction_bound (λ_. True)"

(* Coinduction-style rule for bounding interaction_any_bound: given an invariant X
   relating gpvs and bounds such that every IO step of a related gpv leads, for each
   input, to a continuation that is related to (or already bounded by) some n' with
   eSuc n' within n, the bound of gpv is at most n.
   The proof is a fixpoint induction on interaction_bound. *)
lemma interaction_any_bound_coinduct [consumes 1, case_names interaction_bound]:
  assumes X: "X gpv n"
  and *: "gpv n out c input.  X gpv n; IO out c  set_spmf (the_gpv gpv)  
     n'. (X (c input) n'  interaction_any_bound (c input)  n')  eSuc n'  n"
  shows "interaction_any_bound gpv  n"
using X
proof(induction arbitrary: gpv n rule: interaction_bound_fixp_induct)
  case adm show ?case by(intro cont_intro)
  case bottom show ?case by simp
next
  case (step interaction_bound')
  { fix out c
    assume IO: "IO out c  set_spmf (the_gpv gpv)"
    (* n must be a successor because some eSuc n' lies below it. *)
    from *[OF step.prems IO] obtain n' where n: "n = eSuc n'"
      by(cases n rule: co.enat.exhaust) auto
    moreover 
    { fix input
      have "n''. (X (c input) n''  interaction_any_bound (c input)  n'')  eSuc n''  n"
        using step.prems IO n = eSuc n' by(auto 4 3 dest: *)
      then have "interaction_bound' (c input)  n'" using n
        by(auto dest: step.IH intro: step.hyps[THEN order_trans] elim!: order_trans simp add: neq_zero_conv_eSuc) }
    ultimately have "eSuc (input. interaction_bound' (c input))  n"
      by(auto intro: SUP_least) }
  then show ?case by(auto intro!: SUP_least split: generat.split)
qed

(* Parametricity of interaction_bound. *)
context includes lifting_syntax begin
(* Version for rel_gpv'' A C R under the assumption that R is bi-total.
   Proved through the parametricity of the underlying least fixpoint. *)
lemma interaction_bound_parametric':
  assumes [transfer_rule]: "bi_total R"
  shows "((C ===> (=)) ===> rel_gpv'' A C R ===> (=)) interaction_bound interaction_bound"
unfolding interaction_bound_def[abs_def]
apply(rule rel_funI)
apply(rule fixp_lfp_parametric_eq[OF interaction_bound.mono interaction_bound.mono])
subgoal premises [transfer_rule]
  supply the_gpv_parametric'[transfer_rule] rel_gpv''_eq[relator_eq]
  by transfer_prover
done

(* Version for rel_gpv A C, obtained from the primed lemma after rewriting rel_gpv
   into rel_gpv''; the bi-totality premise is discharged with bi_total_eq.
   Registered as a transfer rule. *)
lemma interaction_bound_parametric [transfer_rule]:
  "((C ===> (=)) ===> rel_gpv A C ===> (=)) interaction_bound interaction_bound"
unfolding rel_gpv_conv_rel_gpv'' by(rule interaction_bound_parametric')(rule bi_total_eq)
end

text ‹
  There is no nice @{const interaction_bound} equation for @{const bind_gpv}, as it computes
  an exact bound, but we only need an upper bound.
  As @{typ enat} is hard to work with (and @{term "∞"} does not constrain a gpv in any way),
  we work with @{typ nat}.
›

(* Predicate form of the bound: interaction_bounded_by consider gpv n is introduced
   from the premise relating interaction_bound consider gpv to n (single
   non-recursive introduction rule). *)
inductive interaction_bounded_by :: "('out  bool)  ('a, 'out, 'in) gpv  enat  bool"
for "consider" gpv n where
  interaction_bounded_by: " interaction_bound consider gpv  n   interaction_bounded_by consider gpv n"

(* Re-export the introduction rule under the name interaction_bounded_byI and hide the
   original fact name, which coincides with the name of the predicate. *)
lemmas interaction_bounded_byI = interaction_bounded_by
hide_fact (open) interaction_bounded_by

(* Parametricity of interaction_bounded_by. *)
context includes lifting_syntax begin
(* Version for rel_gpv A C; registered as a transfer rule. *)
lemma interaction_bounded_by_parametric [transfer_rule]:
  "((C ===> (=)) ===> rel_gpv A C ===> (=) ===> (=)) interaction_bounded_by interaction_bounded_by"
unfolding interaction_bounded_by.simps[abs_def] by transfer_prover

(* Version for rel_gpv'' A C R with R bi-total; uses interaction_bound_parametric'
   as a local transfer rule. *)
lemma interaction_bounded_by_parametric':
  notes interaction_bound_parametric'[transfer_rule]
  assumes [transfer_rule]: "bi_total R"
  shows "((C ===> (=)) ===> rel_gpv'' A C R ===> (=) ===> (=)) 
         interaction_bounded_by interaction_bounded_by"
unfolding interaction_bounded_by.simps[abs_def] by transfer_prover
end

(* Monotonicity in the bound: a gpv bounded by n is also bounded by any m related to
   n as in the second premise (weakening of the bound, by order_trans). *)
lemma interaction_bounded_by_mono:
  " interaction_bounded_by consider gpv n; n  m   interaction_bounded_by consider gpv m"
unfolding interaction_bounded_by.simps by(erule order_trans) simp

(* Destruction rule for a considered IO step of a bounded gpv: the bound n is positive
   and each continuation c input is bounded by n - 1. *)
lemma interaction_bounded_by_contD:
  " interaction_bounded_by consider gpv n; IO out c  set_spmf (the_gpv gpv); consider out 
   n > 0  interaction_bounded_by consider (c input) (n - 1)"
unfolding interaction_bounded_by.simps
by(subst (asm) interaction_bound.simps)(auto simp add: SUP_le_iff eSuc_le_iff enat_eSuc_iff dest!: bspec)

(* Destruction rule without a premise on `consider`: every continuation of an IO step
   of a gpv bounded by n is itself bounded by n. *)
lemma interaction_bounded_by_contD_ignore:
  " interaction_bounded_by consider gpv n; IO out c  set_spmf (the_gpv gpv) 
   interaction_bounded_by consider (c input) n"
unfolding interaction_bounded_by.simps
by(subst (asm) interaction_bound.simps)(auto 4 4 simp add: SUP_le_iff eSuc_le_iff enat_eSuc_iff dest!: bspec split: if_split_asm elim: order_trans)

(* Introduction rule via the continuations: gpv is bounded by n if, for every IO step,
   either the output is considered, n is non-zero and all continuations are bounded by
   n - 1, or the output is not considered and all continuations are bounded by n. *)
lemma interaction_bounded_byI_epred:
  assumes "out c.  IO out c  set_spmf (the_gpv gpv); consider out   n  0  (input. interaction_bounded_by consider (c input) (n - 1))"
  and "out c input.  IO out c  set_spmf (the_gpv gpv); ¬ consider out   interaction_bounded_by consider (c input) n"
  shows "interaction_bounded_by consider gpv n"
unfolding interaction_bounded_by.simps
by(subst interaction_bound.simps)(auto 4 5 intro!: SUP_least split: generat.split dest: assms simp add: eSuc_le_iff enat_eSuc_iff gr0_conv_Suc neq_zero_conv_eSuc interaction_bounded_by.simps)

(* Variant of interaction_bounded_by_contD with the premises in a different order and
   the conclusion stating that n is non-zero; derived from interaction_bound_IO. *)
lemma interaction_bounded_by_IO:
  " IO out c  set_spmf (the_gpv gpv); interaction_bounded_by consider gpv n; consider out 
   n  0  interaction_bounded_by consider (c input) (n - 1)"
by(drule interaction_bound_IO[where input=input and ?consider="consider"])(auto simp add: interaction_bounded_by.simps epred_conv_minus eSuc_le_iff enat_eSuc_iff)

(* Being bounded by 0 is characterised by interaction_bound being equal to 0. *)
lemma interaction_bounded_by_0: "interaction_bounded_by consider gpv 0  interaction_bound consider gpv = 0"
by(simp add: interaction_bounded_by.simps zero_enat_def[symmetric])

(* Variant taking a nat bound, which is embedded into enat with the constructor enat. *)
abbreviation interaction_bounded_by' :: "('out  bool)  ('a, 'out, 'in) gpv  nat  bool"
where "interaction_bounded_by' consider gpv n  interaction_bounded_by consider gpv (enat n)"

(* Rule collection and Eisbach proof methods for deriving interaction bounds.
   Lemmas tagged [interaction_bound] below are the syntax-directed rules. *)
named_theorems interaction_bound

(* The start rule is monotonicity: it separates the bound that the rules compute
   from the bound that was actually requested. *)
lemmas interaction_bounded_by_start = interaction_bounded_by_mono

method interaction_bound_start = (rule interaction_bounded_by_start)
(* One step: on goals that are not interaction_bounded_by statements, try to solve
   them with clarsimp; otherwise apply a rule from `add` or the named collection. *)
method interaction_bound_step uses add simp =
  ((match conclusion in "interaction_bounded_by _ _ _"  fail ¦ _  solvesclarsimp simp add: simp››) | rule add interaction_bound)
(* Apply steps repeatedly to all emerging subgoals for as long as a step succeeds. *)
method interaction_bound_rec uses add simp = 
  (interaction_bound_step add: add simp: simp; (interaction_bound_rec add: add simp: simp)?)
(* Entry point: start rule followed by the recursive stepping. *)
method interaction_bound uses add simp =
  ((* use in *) interaction_bound_start, interaction_bound_rec add: add simp: simp)

(* Done x is bounded by every n (simp rule). *)
lemma interaction_bounded_by_Done [simp]: "interaction_bounded_by consider (Done x) n"
by(simp add: interaction_bounded_by.simps)

(* Rule for the interaction_bound method: Done x gets the concrete bound 0. *)
lemma interaction_bounded_by_DoneI [interaction_bound]:
  "interaction_bounded_by consider (Done x) 0"
by simp

(* Fail is bounded by every n (simp rule). *)
lemma interaction_bounded_by_Fail [simp]: "interaction_bounded_by consider Fail n"
by(simp add: interaction_bounded_by.simps)

(* Rule for the interaction_bound method: Fail gets the concrete bound 0. *)
lemma interaction_bounded_by_FailI [interaction_bound]: "interaction_bounded_by consider Fail 0"
by simp

(* A lifted spmf is bounded by every n (simp rule). *)
lemma interaction_bounded_by_lift_spmf [simp]: "interaction_bounded_by consider (lift_spmf p) n"
by(simp add: interaction_bounded_by.simps)

(* Rule for the interaction_bound method: a lifted spmf gets the concrete bound 0. *)
lemma interaction_bounded_by_lift_spmfI [interaction_bound]:
  "interaction_bounded_by consider (lift_spmf p) 0"
by simp

(* An assertion is bounded by every n (simp rule), by case distinction on b. *)
lemma interaction_bounded_by_assert_gpv [simp]: "interaction_bounded_by consider (assert_gpv b) n"
by(cases b) simp_all

(* Rule for the interaction_bound method: an assertion gets the concrete bound 0. *)
lemma interaction_bounded_by_assert_gpvI [interaction_bound]:
  "interaction_bounded_by consider (assert_gpv b) 0"
by simp

(* Characterisation for Pause out c: if out is considered, n must be positive and all
   continuations bounded by n - 1; otherwise all continuations are bounded by n. *)
lemma interaction_bounded_by_Pause [simp]:
  "interaction_bounded_by consider (Pause out c) n  
  (if consider out then 0 < n  (input. interaction_bounded_by consider (c input) (n - 1)) else (input. interaction_bounded_by consider (c input) n))"
by(cases n rule: co.enat.exhaust)
  (auto 4 3 simp add: interaction_bounded_by.simps eSuc_le_iff enat_eSuc_iff gr0_conv_Suc intro: SUP_least dest: order_trans[OF SUP_upper, rotated])

(* Rule for the interaction_bound method: given per-input bounds n input for the
   continuations, Pause out c is bounded by their supremum, plus 1 if out is considered. *)
lemma interaction_bounded_by_PauseI [interaction_bound]:
  "(input. interaction_bounded_by consider (c input) (n input))
   interaction_bounded_by consider (Pause out c) (if consider out then 1 + (SUP input. n input) else (SUP input. n input))"
by(auto simp add: iadd_is_0 enat_add_sub_same intro: interaction_bounded_by_mono SUP_upper)

(* Rule for bind: if gpv is bounded by n and f x by m x for every result x of gpv,
   the bind is bounded by n plus the supremum of m x over those results.
   Derived from interaction_bound_bind. *)
lemma interaction_bounded_by_bindI [interaction_bound]:
  " interaction_bounded_by consider gpv n; x. x  results'_gpv gpv  interaction_bounded_by consider (f x) (m x) 
   interaction_bounded_by consider (gpv  f) (n + (SUP xresults'_gpv gpv. m x))"
unfolding interaction_bounded_by.simps plus_enat_simps(1)[symmetric]
by(rule interaction_bound_bind[THEN order_trans])(auto intro: add_mono SUP_mono)

(* Rule for a bind whose first component is a Pause: bounds are given for each
   continuation already bound with f, which avoids going through the general bind rule. *)
lemma interaction_bounded_by_bind_PauseI [interaction_bound]:
  "(input. interaction_bounded_by consider (c input  f) (n input))
   interaction_bounded_by consider (Pause out c  f) (if consider out then SUP input. n input + 1 else SUP input. n input)"
by(auto 4 3 simp add: interaction_bounded_by.simps SUP_enat_add_left eSuc_plus_1 intro: SUP_least SUP_upper2)

(* A bind after a lifted spmf is bounded by n exactly when f x is bounded by n for
   every x in the support of p. *)
lemma interaction_bounded_by_bind_lift_spmf [simp]:
  "interaction_bounded_by consider (lift_spmf p  f) n  (xset_spmf p. interaction_bounded_by consider (f x) n)"
by(simp add: interaction_bounded_by.simps SUP_le_iff)

(* Rule for the interaction_bound method: with per-sample bounds n x, a bind after a
   lifted spmf is bounded by the supremum of n x over the support of p. *)
lemma interaction_bounded_by_bind_lift_spmfI [interaction_bound]:
  "(x. x  set_spmf p  interaction_bounded_by consider (f x) (n x))
   interaction_bounded_by consider (lift_spmf p  f) (SUP xset_spmf p. n x)"
by(auto intro: interaction_bounded_by_mono SUP_upper)

(* Rule for a bind after Done x: the bound of f x carries over unchanged. *)
lemma interaction_bounded_by_bind_DoneI [interaction_bound]:
  "interaction_bounded_by consider (f x) n  interaction_bounded_by consider (Done x  f) n"
by(simp)

(* Rule for if-then-else: the bound is the if-then-else of the branch bounds; each
   branch only has to be bounded under its own branch condition. *)
lemma interaction_bounded_by_if [interaction_bound]:
  " b  interaction_bounded_by consider gpv1 n; ¬ b  interaction_bounded_by consider gpv2 m 
   interaction_bounded_by consider (if b then gpv1 else gpv2) (if b then n else m)"
by(auto 4 3 simp add: max_def not_le elim: interaction_bounded_by_mono)

(* Rule for case_bool, analogous to the if-then-else rule. *)
lemma interaction_bounded_by_case_bool [interaction_bound]:
  " b  interaction_bounded_by consider t bt; ¬ b  interaction_bounded_by consider f bf 
   interaction_bounded_by consider (case_bool t f b) (if b then bt else bf)"
by(cases b)(auto)

(* Rule for case_sum: the bound is the case_sum of the per-branch bound functions bl and br. *)
lemma interaction_bounded_by_case_sum [interaction_bound]:
  " y. x = Inl y  interaction_bounded_by consider (l y) (bl y);
     y. x = Inr y  interaction_bounded_by consider (r y) (br y) 
   interaction_bounded_by consider (case_sum l r x) (case_sum bl br x)"
by(cases x)(auto)

(* Rule for case_prod: the bound is case_prod of the bound function n on the same pair. *)
lemma interaction_bounded_by_case_prod [interaction_bound]:
  "(a b. x = (a, b)  interaction_bounded_by consider (f a b) (n a b))
   interaction_bounded_by consider (case_prod f x) (case_prod n x)"
by(simp split: prod.split)

(* Rule for Let: the bound is established on the body with the bound term substituted. *)
lemma interaction_bounded_by_let [interaction_bound]: ― ‹This rule unfolds let's›
  "interaction_bounded_by consider (f t) m  interaction_bounded_by consider (Let t f) m"
by(simp add: Let_def)

(* Mapping only the results (with id in the second argument of map_gpv) preserves the
   bound. Proved by rewriting map_gpv into a bind and running the interaction_bound method. *)
lemma interaction_bounded_by_map_gpv_id [interaction_bound]:
  assumes [interaction_bound]: "interaction_bounded_by P gpv n"
  shows "interaction_bounded_by P (map_gpv f id gpv) n"
unfolding id_def map_gpv_conv_bind by interaction_bound simp

(* interaction_bounded_by where every output is considered (predicate constantly True). *)
abbreviation interaction_any_bounded_by :: "('a, 'out, 'in) gpv  enat  bool"
where "interaction_any_bounded_by  interaction_bounded_by (λ_. True)"

(* map_gpv' f g h preserves interaction_any_bounded_by when h is surjective;
   a consequence of interaction_bound_map_gpv'. *)
lemma interaction_any_bounded_by_map_gpv':
  assumes "interaction_any_bounded_by gpv n"
    and "surj h"
  shows "interaction_any_bounded_by (map_gpv' f g h gpv) n"
  using assms by(simp add: interaction_bounded_by.simps interaction_bound_map_gpv' o_def)

subsection ‹Typing›

subsubsection ‹Interface between gpvs and rpvs / callees›

(* Set.is_empty is parametric: sets related by rel_set A are both empty or both
   non-empty. The author's marker notes that the lemma should move elsewhere. *)
lemma is_empty_parametric [transfer_rule]: "rel_fun (rel_set A) (=) Set.is_empty Set.is_empty" (* Move *)
by(auto simp add: rel_fun_def Set.is_empty_def dest: rel_setD1 rel_setD2)

(* Type of interfaces, a copy of the full function space from calls to sets of
   returns (the carrier is UNIV, so non-emptiness is trivial). *)
typedef ('call, 'ret)= "UNIV :: ('call  'ret set) set" ..

(* Register the new type with the Lifting package. *)
setup_lifting type_definition_ℐ

(* Parametricity of the raw function underlying outs_ℐ, for a bi-total relation A.
   The set comprehension is first folded into Set.is_empty so that the transfer rule
   is_empty_parametric applies. Used as the `parametric` theorem of outs_ℐ. *)
lemma outs_ℐ_tparametric:
  includes lifting_syntax 
  assumes [transfer_rule]: "bi_total A"
  shows "((A ===> rel_set B) ===> rel_set A) (λresps. {out. resps out  {}}) (λresps. {out. resps out  {}})"
  by(fold Set.is_empty_def) transfer_prover

(* outs_ℐ: the set of calls selected by the comprehension over the response function.
   responses_ℐ: the underlying response function itself (lifted identity). *)
lift_definition outs_ℐ :: "('call, 'ret) 'call set" is "λresps. {out. resps out  {}}" parametric outs_ℐ_tparametric .
lift_definition responses_ℐ :: "('call, 'ret) 'call  'ret set" is "λx. x" parametric id_transfer[unfolded id_def] .

(* Relator for interfaces: combines rel_set C on the sets of calls with
   rel_fun C (rel_set R) on the response functions. *)
lift_definition rel_ℐ :: "('call  'call'  bool)  ('ret  'ret'  bool)  ('call, 'ret) ('call', 'ret') bool"
is "λC R resp1 resp2. rel_set C {out. resp1 out  {}} {out. resp2 out  {}}  rel_fun C (rel_set R) resp1 resp2"
.

(* Introduction rule for rel_ℐ in terms of outs_ℐ and responses_ℐ. *)
lemma rel_ℐI [intro?]:
  " rel_set C (outs_ℐ ℐ1) (outs_ℐ ℐ2); x y. C x y  rel_set R (responses_ℐ ℐ1 x) (responses_ℐ ℐ2 y) 
   rel_ℐ C R ℐ1 ℐ2"
by transfer(auto simp add: rel_fun_def)

(* rel_ℐ instantiated with equalities is equality on interfaces. *)
lemma rel_ℐ_eq [relator_eq]: "rel_ℐ (=) (=) = (=)"
unfolding fun_eq_iff by transfer(auto simp add: relator_eq)

(* rel_ℐ commutes with converse. After transfer, the rewrite steps re-introduce
   converses at chosen positions so that rel_fun_conversep can be applied. *)
lemma rel_ℐ_conversep [simp]: "rel_ℐ C¯¯ R¯¯ = (rel_ℐ C R)¯¯"
unfolding fun_eq_iff conversep_iff
apply transfer
apply(rewrite in "rel_fun " conversep_iff[symmetric])
apply(rewrite in "rel_set " conversep_iff[symmetric])
apply(rewrite in "rel_fun _ " conversep_iff[symmetric])
apply(simp del: conversep_iff add: rel_fun_conversep)
apply(simp)
done

(* Converse rule for rel_ℐ when the second relation is equality. *)
lemma rel_ℐ_conversep1_eq [simp]: "rel_ℐ C¯¯ (=) = (rel_ℐ C (=))¯¯"
by(rewrite in " = _" conversep_eq[symmetric])(simp del: conversep_eq)

(* Converse rule for rel_ℐ when the first relation is equality. *)
lemma rel_ℐ_conversep2_eq [simp]: "rel_ℐ (=) R¯¯ = (rel_ℐ (=) R)¯¯"
by(rewrite in " = _" conversep_eq[symmetric])(simp del: conversep_eq)

(* Relates emptiness of the response set of out to membership of out in outs_ℐ. *)
lemma responses_ℐ_empty_iff: "responses_ℐ  out = {}  out  outs_ℐ "
including ℐ.lifting by transfer auto

(* Membership in outs_ℐ expressed through the response set; follows from
   responses_ℐ_empty_iff. *)
lemma in_outs_ℐ_iff_responses_ℐ: "out  outs_ℐ   responses_ℐ  out  {}"
by(simp add: responses_ℐ_empty_iff)

(* The full interface: every call has all returns (UNIV) as responses. *)
lift_definition ℐ_full :: "('call, 'ret) ℐ" is "λ_. UNIV" .

(* Selector equations for ℐ_full: all calls are outputs and all returns are responses.
   Both parts are also available individually under their own names. *)
lemma ℐ_full_sel [simp]:
  shows outs_ℐ_full: "outs_ℐ ℐ_full = UNIV"
  and responses_ℐ_full: "responses_ℐ ℐ_full x = UNIV"
by(transfer; simp; fail)+

(* Parametricity of the interface selectors with respect to rel_ℐ. *)
context includes lifting_syntax begin
(* outs_ℐ maps rel_ℐ-related interfaces to rel_set-related sets of calls. *)
lemma outs_ℐ_parametric [transfer_rule]: "(rel_ℐ C R ===> rel_set C) outs_ℐ outs_ℐ"
unfolding rel_fun_def by transfer simp

(* responses_ℐ maps related interfaces and related calls to related response sets. *)
lemma responses_ℐ_parametric [transfer_rule]: 
  "(rel_ℐ C R ===> C ===> rel_set R) responses_ℐ responses_ℐ"
unfolding rel_fun_def by transfer(auto dest: rel_funD)

end

(* An interface is trivial when its set of outputs is UNIV. *)
definition ℐ_trivial :: "('out, 'in) bool"
where "ℐ_trivial   outs_ℐ  = UNIV"

(* Introduction rule: triviality follows once every x is shown to be in outs_ℐ. *)
lemma ℐ_trivialI [intro?]: "(x. x  outs_ℐ )  ℐ_trivial "
by(auto simp add: ℐ_trivial_def)

(* Destruction rule: a trivial interface has UNIV as its set of outputs. *)
lemma ℐ_trivialD: "ℐ_trivial   outs_ℐ  = UNIV"
by(simp add: ℐ_trivial_def)

(* The full interface is trivial. *)
lemma ℐ_trivial_ℐ_full [simp]: "ℐ_trivial ℐ_full"
by(simp add: ℐ_trivial_def)

(* Store the transfer rules collected so far in the bundle ℐ.lifting and deactivate
   the lifting setup globally; later proofs re-enable it locally by including the bundle. *)
lifting_update ℐ.lifting
lifting_forget ℐ.lifting

(* Constructions on interfaces, developed with the lifting setup for the interface
   type enabled locally. *)
context includes ℐ.lifting begin

(* Uniform interface: calls in A have response set B, all other calls have none. *)
lift_definition ℐ_uniform :: "'out set  'in set  ('out, 'in) ℐ" is "λA B x. if x  A then B else {}" .

(* With B empty no call has a response, so the set of outputs is empty rather than A. *)
lemma outs_ℐ_uniform [simp]: "outs_ℐ (ℐ_uniform A B) = (if B = {} then {} else A)"
  by transfer simp

lemma responses_ℐ_uniform [simp]: "responses_ℐ (ℐ_uniform A B) x = (if x  A then B else {})"
  by transfer simp

lemma ℐ_uniform_UNIV [simp]: "ℐ_uniform UNIV UNIV = ℐ_full" (* TODO: make ℐ_full an abbreviation *)
  by transfer simp

(* Mapper for interfaces: calls are translated by f before the lookup, and the
   responses found are translated by g (image). *)
lift_definition map_ℐ :: "('out'  'out)  ('in  'in')  ('out, 'in) ('out', 'in') ℐ"
  is "λf g resp x. g ` resp (f x)" .

(* Outputs of a mapped interface are the preimage of the original outputs under f. *)
lemma outs_ℐ_map_ℐ [simp]:
  "outs_ℐ (map_ℐ f g ) = f -` outs_ℐ "
  by transfer simp

lemma responses_ℐ_map_ℐ [simp]:
  "responses_ℐ (map_ℐ f g ) x = g ` responses_ℐ  (f x)"
  by transfer simp

(* Mapping a uniform interface gives a uniform interface again. *)
lemma map_ℐ_ℐ_uniform [simp]:
  "map_ℐ f g (ℐ_uniform A B) = ℐ_uniform (f -` A) (g ` B)"
  by transfer(auto simp add: fun_eq_iff)

(* Functor laws: identity (pointwise and point-free) and composition. *)
lemma map_ℐ_id [simp]: "map_ℐ id id  = "
  by transfer simp

lemma map_ℐ_id0: "map_ℐ id id = id"
  by(simp add: fun_eq_iff)

lemma map_ℐ_comp [simp]: "map_ℐ f g (map_ℐ f' g' ) = map_ℐ (f'  f) (g  g') "
  by transfer auto

(* Congruence rule: g and g' only need to agree on responses of outputs of the interface. *)
lemma map_ℐ_cong: "map_ℐ f g  = map_ℐ f' g' ℐ'"
  if " = ℐ'" and f: "f = f'" and "x y.  x  outs_ℐ ℐ'; y  responses_ℐ ℐ' x   g y = g' y"
  unfolding that(1,2) using that(3-)
  by transfer(auto simp add: fun_eq_iff intro!: image_cong)

lifting_update ℐ.lifting
lifting_forget ℐ.lifting
end

(* Register map_ℐ with the functor command; its proof obligations are closed by simp. *)
functor map_ℐ by(simp_all add: fun_eq_iff)

(* Extensionality for interfaces: equal sets of outputs, and equal responses on those
   outputs, imply that the interfaces are equal. *)
lemma ℐ_eqI: " outs_ℐ  = outs_ℐ ℐ'; x. x  outs_ℐ ℐ'  responses_ℐ  x = responses_ℐ ℐ' x    = ℐ'"
  including ℐ.lifting by transfer auto

(* Order instance for interfaces. The definition le_ℐ_def compares the sets of outputs
   in one direction and the response sets in the opposite direction, restricted to the
   outputs of the first interface. The strict order is derived with mk_less. *)
instantiation:: (type, type) order begin

definition less_eq_ℐ :: "('a, 'b) ('a, 'b) bool"
  where le_ℐ_def: "less_eq_ℐ  ℐ'  outs_ℐ   outs_ℐ ℐ'  (xouts_ℐ . responses_ℐ ℐ' x  responses_ℐ  x)"

definition less_ℐ :: "('a, 'b) ('a, 'b) bool"
  where "less_ℐ = mk_less (≤)"

(* The four `show` statements discharge, in order: the characterisation of the strict
   order, reflexivity, transitivity and antisymmetry (the last via ℐ_eqI). *)
instance
proof
  show " < ℐ'    ℐ'  ¬ ℐ'  " for  ℐ' :: "('a, 'b) ℐ" by(simp add: less_ℐ_def mk_less_def)
  show "  " for  :: "('a, 'b) ℐ" by(simp add: le_ℐ_def)
  show "  ℐ''" if "  ℐ'" "ℐ'  ℐ''" for  ℐ' ℐ'' :: "('a, 'b) ℐ" using that
    by(fastforce simp add: le_ℐ_def)
  show " = ℐ'" if "  ℐ'" "ℐ'  " for  ℐ' :: "('a, 'b) ℐ" using that
    by(auto simp add: le_ℐ_def intro!: ℐ_eqI)
qed
end

(* Bottom element of the interface order: the uniform interface over the empty set
   of calls, i.e. an interface without outputs. *)
instantiation:: (type, type) order_bot begin
definition bot_ℐ :: "('a, 'b) ℐ" where "bot_ℐ = ℐ_uniform {} UNIV"
instance by standard(auto simp add: bot_ℐ_def le_ℐ_def)
end

(* The bottom interface has no outputs. *)
lemma outs_ℐ_bot [simp]: "outs_ℐ bot = {}"
  by(simp add: bot_ℐ_def)

(* The bottom interface has no responses for any call.
   NOTE: the fact name is misspelled ("respones"); it is kept unchanged so that
   existing references continue to resolve. *)
lemma respones_ℐ_bot [simp]: "responses_ℐ bot x = {}"
  by(simp add: bot_ℐ_def)

(* Correctly spelled alias for the lemma above; carries no extra attributes, so the
   simp set is unchanged. *)
lemmas responses_ℐ_bot = respones_ℐ_bot

(* First conjunct of le_ℐ_def as a rule: the interface order relates the sets of outputs. *)
lemma outs_ℐ_mono: "  ℐ'  outs_ℐ   outs_ℐ ℐ'"
  by(simp add: le_ℐ_def)

(* Second conjunct of le_ℐ_def as a rule: for an output x of the first interface, the
   response sets of the two interfaces at x are related. *)
lemma responses_ℐ_mono: "   ℐ'; x  outs_ℐ    responses_ℐ ℐ' x  responses_ℐ  x"
  by(simp add: le_ℐ_def)

(* A uniform interface over the empty set of calls is bot, whatever the response set A. *)
lemma ℐ_uniform_empty [simp]: "ℐ_uniform {} A = bot" 
  unfolding bot_ℐ_def including ℐ.lifting by transfer simp

(* Monotonicity of ℐ_uniform under the three side conditions on the call and response
   sets; the last one covers the case where D is empty. *)
lemma ℐ_uniform_mono:
  "ℐ_uniform A B  ℐ_uniform C D" if "A  C" "D  B" "D = {}  B = {}"
  unfolding le_ℐ_def using that by auto


(* Results of a gpv relative to an interface Γ: only inputs that are responses of Γ to
   the emitted output are followed. The inductive predicate is qualified; users work
   with the set results_gpv. *)
context begin
(* x is a result if gpv can return it directly (Pure), or if some IO step followed by
   an input from responses_ℐ Γ out leads to a continuation that has x as a result. *)
qualified inductive resultsp_gpv :: "('out, 'in) 'a  ('a, 'out, 'in) gpv  bool"
  for Γ x
where
  Pure: "Pure x  set_spmf (the_gpv gpv)  resultsp_gpv Γ x gpv"
| IO:
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ Γ out; resultsp_gpv Γ x (c input) 
   resultsp_gpv Γ x gpv"

(* Set version of the predicate. *)
definition results_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  'a set"
where "results_gpv Γ gpv  {x. resultsp_gpv Γ x gpv}"

(* Conversion rule between predicate and set form, used by the to_set attribute. *)
lemma resultsp_gpv_results_gpv_eq [pred_set_conv]: "resultsp_gpv Γ x gpv  x  results_gpv Γ gpv"
by(simp add: results_gpv_def)

(* Export the rules of the predicate in set form under the name prefix results_gpv. *)
context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "results_gpv")

lemmas intros [intro?] = resultsp_gpv.intros[to_set]
  and Pure = Pure[to_set]
  and IO = IO[to_set]
  and induct [consumes 1, case_names Pure IO, induct set: results_gpv] = resultsp_gpv.induct[to_set]
  and cases [consumes 1, case_names Pure IO, cases set: results_gpv] = resultsp_gpv.cases[to_set]
  and simps = resultsp_gpv.simps[to_set]
end

(* Simp rule for results of an explicit GPV constructor, in set form. *)
inductive_simps results_gpv_GPV [to_set, simp]: "resultsp_gpv Γ x (GPV gpv)"

end

(* Done x has exactly the result x. *)
lemma results_gpv_Done [iff]: "results_gpv Γ (Done x) = {x}"
by(auto simp add: Done.ctr)

(* Fail has no results. *)
lemma results_gpv_Fail [iff]: "results_gpv Γ Fail = {}"
by(auto simp add: Fail_def)

(* Results of Pause out c: the union of the results of c input over all inputs that
   are responses of Γ to out. *)
lemma results_gpv_Pause [simp]:
  "results_gpv Γ (Pause out c) = (inputresponses_ℐ Γ out. results_gpv Γ (c input))"
by(auto simp add: Pause.ctr)

(* The results of a lifted spmf are its support, independently of the interface. *)
lemma results_gpv_lift_spmf [iff]: "results_gpv Γ (lift_spmf p) = set_spmf p"
by(auto simp add: lift_spmf.ctr)

(* An assertion returns () if b holds and has no results otherwise. *)
lemma results_gpv_assert_gpv [simp]: "results_gpv Γ (assert_gpv b) = (if b then {()} else {})"
by auto

(* Results of a bind: the union of the results of f x over all results x of gpv.
   Both inclusions are proved by induction on the derivation of the result. *)
lemma results_gpv_bind_gpv [simp]:
  "results_gpv Γ (gpv  f) = (xresults_gpv Γ gpv. results_gpv Γ (f x))"
  (is "?lhs = ?rhs")
proof(intro set_eqI iffI)
  fix x
  assume "x  ?lhs"
  then show "x  ?rhs"
  (* Induction on the result derivation for the bind, with gpv generalised. *)
  proof(induction gpv'"gpv  f" arbitrary: gpv)
    case Pure thus ?case
      by(auto 4 3 split: if_split_asm intro: results_gpv.intros rev_bexI)
  next
    case (IO out c input)
    (* Determine which generat of gpv gave rise to the IO step of the bind. *)
    from ‹IO out c  _
    obtain generat where generat: "generat  set_spmf (the_gpv gpv)"
      and *: "IO out c  set_spmf (if is_Pure generat then the_gpv (f (result generat))
                                   else return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
      by(auto)
    thus ?case
    proof(cases generat)
      (* gpv has already returned y; the IO step belongs to f y. *)
      case (Pure y)
      with generat have "y  results_gpv Γ gpv" by(auto intro: results_gpv.intros)
      thus ?thesis using * Pure input  responses_ℐ Γ out x  results_gpv Γ (c input)
        by(auto intro: results_gpv.IO)
    next
      (* The IO step belongs to gpv itself; use the induction hypothesis on the continuation. *)
      case (IO out' c')
      hence [simp]: "out' = out"
        and c: "input. c input = bind_gpv (c' input) f" using * by simp_all
      from IO.hyps(4)[OF c] obtain y where y: "y  results_gpv Γ (c' input)"
        and "x  results_gpv Γ (f y)" by blast
      from y IO generat have "y  results_gpv Γ gpv" using input  responses_ℐ Γ out
        by(auto intro: results_gpv.IO)
      with x  results_gpv Γ (f y) show ?thesis by blast
    qed
  qed
next
  fix x
  assume "x  ?rhs"
  then obtain y where y: "y  results_gpv Γ gpv"
    and x: "x  results_gpv Γ (f y)" by blast
  (* Induction on how y arises as a result of gpv. *)
  from y show "x  ?lhs"
  proof(induction)
    case (Pure gpv)
    with x show ?case
      by cases(auto 4 4 intro: results_gpv.intros rev_bexI)
  qed(auto 4 4 intro: rev_bexI results_gpv.IO)
qed

(* For the full interface, results_gpv coincides with the interface-free results'_gpv.
   Each inclusion is shown by induction on the respective derivation. *)
lemma results_gpv_ℐ_full: "results_gpv ℐ_full = results'_gpv"
proof(intro ext set_eqI iffI)
  show "x  results'_gpv gpv" if "x  results_gpv ℐ_full gpv" for x gpv
    using that by induction(auto intro: results'_gpvI)
  show "x  results_gpv ℐ_full gpv" if "x  results'_gpv gpv" for x gpv
    using that by induction(auto intro: results_gpv.intros elim!: generat.set_cases)
qed

(* Bind equation for results'_gpv, obtained from the results_gpv version by
   instantiating the interface with ℐ_full. *)
lemma results'_bind_gpv [simp]:
  "results'_gpv (bind_gpv gpv f) = (xresults'_gpv gpv. results'_gpv (f x))"
unfolding results_gpv_ℐ_full[symmetric] by simp

(* Mapping the results with f (and id on the second argument) maps the result set by f. *)
lemma results_gpv_map_gpv_id [simp]: "results_gpv  (map_gpv f id gpv) = f ` results_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

(* Same as results_gpv_map_gpv_id with the identity written as an explicit lambda,
   so that the simp rule also matches terms in which id has been unfolded. *)
lemma results_gpv_map_gpv_id' [simp]: "results_gpv  (map_gpv f (λx. x) gpv) = f ` results_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

(* pred_gpv over a bind: the result predicate of gpv becomes pred_gpv P Q composed
   with f, while the second predicate Q stays the same. *)
lemma pred_gpv_bind [simp]: "pred_gpv P Q (bind_gpv gpv f) = pred_gpv (pred_gpv P Q  f) Q gpv"
by(auto simp add: pred_gpv_def outs_bind_gpv)

(* Results of an option-bind with Fail as the failure value: union of the results of
   f y over the elements of the option x. Proved by cases on x. *)
lemma results'_gpv_bind_option [simp]:
  "results'_gpv (monad.bind_option Fail x f) = (yset_option x. results'_gpv (f y))"
by(cases x) simp_all

(* If h is surjective, the results of map_gpv' f g h gpv are the image under f of the
   results of gpv. *)
lemma results'_gpv_map_gpv':
  assumes "surj h"
  shows "results'_gpv (map_gpv' f g h gpv) = f ` results'_gpv gpv" (is "?lhs = ?rhs")
proof -
  (* Auxiliary step for IO generats: a result of the mapped continuation is a result
     of the mapped gpv. Surjectivity of h provides a preimage for the given input. *)
  have *:"IO z c  set_spmf (the_gpv gpv)  x  results'_gpv (c input) 
     f x  results'_gpv (map_gpv' f g h (c input))  f x  results'_gpv (map_gpv' f g h gpv)" for x z gpv c input
    using surjD[OF assms, of input] by(fastforce intro: results'_gpvI elim!: generat.set_cases intro: rev_image_eqI simp add: map_fun_def o_def)

  show ?thesis 
  proof(intro Set.set_eqI iffI; (elim imageE; hypsubst)?)
    show "x  ?rhs" if "x  ?lhs" for x using that
      by(induction gpv'"map_gpv' f g h gpv" arbitrary: gpv)(fastforce elim!: generat.set_cases intro: results'_gpvI)+
    show "f x  ?lhs" if "x  results'_gpv gpv" for x using that
      by induction (fastforce intro: results'_gpvI elim!: generat.set_cases intro: rev_image_eqI simp add: map_fun_def o_def
          , clarsimp simp add: *  elim!: generat.set_cases)
  qed
qed

(* Associativity between the option-bind (with Fail for the empty option) and bind_gpv:
   binding g after the option-bind equals option-binding the continuation that binds g
   after f x.  Proved by case distinction on x. *)
lemma bind_gpv_bind_option_assoc:
  "bind_gpv (monad.bind_option Fail x f) g = monad.bind_option Fail x (λx. bind_gpv (f x) g)"
by(cases x) simp_all

(* Outputs of a gpv relative to an interface.
   The predicate outsp_gpv is defined inductively by two rules:
     IO:   x is the output of an IO element in the support of the_gpv gpv;
     Cont: IO out rpv is in the support of the_gpv gpv, input is among the responses of the
           interface to out, and x is (inductively) an output of rpv input.
   outs_gpv is the corresponding set of all such x. *)
context begin
qualified inductive outsp_gpv :: "('out, 'in) 'out  ('a, 'out, 'in) gpv  bool"
  for  x where
    IO: "IO x c  set_spmf (the_gpv gpv)  outsp_gpv  x gpv"
  | Cont: " IO out rpv  set_spmf (the_gpv gpv); input  responses_ℐ  out; outsp_gpv  x (rpv input) 
   outsp_gpv  x gpv"

definition outs_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  'out set"
  where "outs_gpv  gpv  {x. outsp_gpv  x gpv}"

(* Conversion between the predicate and the set formulation; registered as pred_set_conv
   so that the to_set attribute used below can translate the rules. *)
lemma outsp_gpv_outs_gpv_eq [pred_set_conv]: "outsp_gpv  x = (λgpv. x  outs_gpv  gpv)"
  by(simp add: outs_gpv_def)

(* Set-based versions of the inductive rules, stored under the mandatory name prefix
   "outs_gpv" (e.g. outs_gpv.IO, outs_gpv.Cont, outs_gpv.induct). *)
context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "outs_gpv")

lemmas intros [intro?] = outsp_gpv.intros[to_set]
  and IO = IO[to_set]
  and Cont = Cont[to_set]
  and induct [consumes 1, case_names IO Cont, induct set: outs_gpv] = outsp_gpv.induct[to_set]
  and cases [consumes 1, case_names IO Cont, cases set: outs_gpv] = outsp_gpv.cases[to_set]
  and simps = outsp_gpv.simps[to_set]
end

(* Simp rule (in set form) for the outputs of a gpv given by the explicit constructor GPV. *)
inductive_simps outs_gpv_GPV [to_set, simp]: "outsp_gpv  x (GPV gpv)"

end

(* Done x has no outputs, for any interface. *)
lemma outs_gpv_Done [iff]: "outs_gpv  (Done x) = {}"
  by(auto simp add: Done.ctr)

(* Fail has no outputs, for any interface. *)
lemma outs_gpv_Fail [iff]: "outs_gpv  Fail = {}"
  by(auto simp add: Fail_def)

(* Simp rule: the outputs of Pause out c are out itself together with the outputs of the
   continuations c input, for input ranging over the interface's responses to out. *)
lemma outs_gpv_Pause [simp]:
  "outs_gpv  (Pause out c) = insert out (inputresponses_ℐ  out. outs_gpv  (c input))"
  by(auto simp add: Pause.ctr)

(* A gpv obtained by lifting an spmf has no outputs. *)
lemma outs_gpv_lift_spmf [iff]: "outs_gpv  (lift_spmf p) = {}"
  by(auto simp add: lift_spmf.ctr)

(* assert_gpv b has no outputs, whichever value b takes (case distinction on b). *)
lemma outs_gpv_assert_gpv [simp]: "outs_gpv  (assert_gpv b) = {}"
  by(cases b)auto

(* Simp rule: the outputs of a bind are the outputs of gpv combined with the outputs of
   f x for the results x of gpv (all w.r.t. the same interface).
   First direction: induction on the outs_gpv derivation for the bind, with the bound gpv
   generalised; each case distinguishes the shape of the generat element involved.
   Second direction: split on whether x is an output of gpv itself (case out) or of f y
   for some result y of gpv (case result), then induct on the respective derivation. *)
lemma outs_gpv_bind_gpv [simp]:
  "outs_gpv  (gpv  f) = outs_gpv  gpv  (xresults_gpv  gpv. outs_gpv  (f x))"
  (is "?lhs = ?rhs")
proof(intro Set.set_eqI iffI)
  fix x
  assume "x  ?lhs"
  then show "x  ?rhs"
  proof(induction gpv'"gpv  f" arbitrary: gpv)
    case IO thus ?case
    proof(clarsimp split: if_split_asm elim!: is_PureE not_is_PureE, goal_cases)
      case (1 generat)
      then show ?case by(cases generat)(auto intro: results_gpv.Pure outs_gpv.intros)
    qed
  next
    case (Cont out rpv input)
    thus ?case
    proof(clarsimp split: if_split_asm, goal_cases)
      case (1 generat)
      then show ?case by(cases generat)(auto 4 3 split: if_split_asm intro: results_gpv.intros outs_gpv.intros)
    qed
  qed
next
  fix x
  assume "x  ?rhs"
  then consider (out) "x  outs_gpv  gpv" | (result) y where "y  results_gpv  gpv" "x  outs_gpv  (f y)" by auto
  then show "x  ?lhs"
  proof cases
    case out then show ?thesis
      by(induction) (auto 4 4 intro: outs_gpv.IO  outs_gpv.Cont rev_bexI) 
  next
    case result then show ?thesis
      by induction ((erule outs_gpv.cases | rule outs_gpv.Cont), 
          auto 4 4 intro: outs_gpv.intros rev_bexI elim: outs_gpv.cases)+
  qed
qed

(* For the full interface ℐ_full, outs_gpv coincides with the interface-free outs'_gpv.
   Each inclusion is proved by induction on the respective membership derivation. *)
lemma outs_gpv_ℐ_full: "outs_gpv ℐ_full = outs'_gpv"
proof(intro ext Set.set_eqI iffI)
  show "x  outs'_gpv gpv" if "x  outs_gpv ℐ_full gpv" for x gpv
    using that by induction(auto intro: outs'_gpvI)
  show "x  outs_gpv ℐ_full gpv" if "x  outs'_gpv gpv" for x gpv
    using that by induction(auto intro: outs_gpv.intros elim!: generat.set_cases)
qed

(* Simp rule: interface-free counterpart of outs_gpv_bind_gpv.  Obtained by rewriting
   outs'_gpv and results'_gpv into their ℐ_full instances (outs_gpv_ℐ_full and
   results_gpv_ℐ_full, both used right-to-left) and simplifying. *)
lemma outs'_bind_gpv [simp]:
  "outs'_gpv (bind_gpv gpv f) = outs'_gpv gpv  (xresults'_gpv gpv. outs'_gpv (f x))"
  unfolding outs_gpv_ℐ_full[symmetric] results_gpv_ℐ_full[symmetric] by simp

(* Simp rule: map_gpv f id leaves the set of outputs unchanged.  Proved by expressing
   map_gpv as a bind (map_gpv_conv_bind). *)
lemma outs_gpv_map_gpv_id [simp]: "outs_gpv  (map_gpv f id gpv) = outs_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

(* Variant of outs_gpv_map_gpv_id with the identity written as the lambda term (λx. x). *)
lemma outs_gpv_map_gpv_id' [simp]: "outs_gpv  (map_gpv f (λx. x) gpv) = outs_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

(* Simp rule: outputs of an option-bind into a gpv (with Fail for the empty option) are
   the outputs of f y for the elements y of set_option x.  Case distinction on x. *)
lemma outs'_gpv_bind_option [simp]:
  "outs'_gpv (monad.bind_option Fail x f) = (yset_option x. outs'_gpv (f y))"
  by(cases x) simp_all

(* Characterisation of rel_gpv'' on graph relations (BNF_Def.Grp):
   relating results by the graph of f on A, outputs by the graph of g on B, and inputs by
   the converse of the graph of h is the same as the graph of map_gpv' f g h, restricted to
   those gpvs whose results lie in A and whose outputs lie in B, both taken w.r.t. the
   interface ℐ_uniform UNIV (range h).
   Left-to-right: the functional equation is shown by coinduction, the two containments by
   induction on the results_gpv / outs_gpv derivations, each time extracting a pred_spmf
   fact about the_gpv gpv from the relation via spmf_Domainp_rel.
   Right-to-left: coinduction after substituting gpv' by the mapped gpv. *)
lemma rel_gpv''_Grp: includes lifting_syntax shows
  "rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯ = 
   BNF_Def.Grp {x. results_gpv (ℐ_uniform UNIV (range h)) x  A  outs_gpv (ℐ_uniform UNIV (range h)) x  B} (map_gpv' f g h)"
  (is "?lhs = ?rhs")
proof(intro ext GrpI iffI CollectI conjI subsetI)
  let ?ℐ = "ℐ_uniform UNIV (range h)"
  fix gpv gpv'
  assume *: "?lhs gpv gpv'"
  then show "map_gpv' f g h gpv = gpv'"
    by(coinduction arbitrary: gpv gpv')
      (drule rel_gpv''D
        , auto 4 5 simp add: spmf_rel_map generat.rel_map elim!: rel_spmf_mono generat.rel_mono_strong GrpE intro!: GrpI dest: rel_funD)
  show "x  A" if "x  results_gpv ?ℐ gpv" for x using that *
  proof(induction arbitrary: gpv')
    case (Pure gpv)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using Pure.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] ..
    with Pure.hyps show ?case by(simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Domainp_Grp)
  next
    case (IO out c gpv input)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using IO.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
    with IO.hyps show ?case 
      by(auto simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Grp_iff dest: rel_funD intro: IO.IH dest!: bspec)
  qed
  show "x  B" if "x  outs_gpv ?ℐ gpv" for x using that *
  proof(induction arbitrary: gpv')
    case (IO c gpv)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using IO.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
    with IO.hyps show ?case by(simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Domainp_Grp)
  next
    case (Cont out rpv gpv input)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using Cont.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
    with Cont.hyps show ?case 
      by(auto simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Grp_iff dest: rel_funD intro: Cont.IH dest!: bspec)
  qed
next
  fix gpv gpv'
  assume "?rhs gpv gpv'"
  then have gpv': "gpv' = map_gpv' f g h gpv"
    and *: "results_gpv (ℐ_uniform UNIV (range h)) gpv  A" "outs_gpv (ℐ_uniform UNIV (range h)) gpv  B" by(auto simp add: Grp_iff)
  show "?lhs gpv gpv'" using * unfolding gpv'
    by(coinduction arbitrary: gpv)
      (fastforce simp add: spmf_rel_map generat.rel_map Grp_iff intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI elim!: generat.set_cases intro: results_gpv.intros outs_gpv.intros)
qed

(* Predicate lifting for gpvs relative to a set X of inputs:
   pred_gpv' P Q X gpv holds if P holds for every result and Q for every output of gpv,
   where results and outputs are taken w.r.t. the interface ℐ_uniform UNIV X.
   Stated as a (non-recursive) inductive predicate with a single introduction rule. *)
inductive pred_gpv' :: "('a  bool)  ('out  bool)  'in set  ('a, 'out, 'in) gpv  bool" for P Q X gpv where
  "pred_gpv' P Q X gpv" 
if "x. x  results_gpv (ℐ_uniform UNIV X) gpv  P x" "out. out  outs_gpv (ℐ_uniform UNIV X) gpv  Q out"

(* pred_gpv is the instance of pred_gpv' with input set UNIV.  The proof unfolds both
   definitions and uses the ℐ_full characterisations of results_gpv and outs_gpv. *)
lemma pred_gpv_conv_pred_gpv': "pred_gpv P Q = pred_gpv' P Q UNIV"
  by(auto simp add: fun_eq_iff pred_gpv_def pred_gpv'.simps results_gpv_ℐ_full outs_gpv_ℐ_full)

(* If gpv and gpv' are related with the converse graph of h as input relation, then
   map_gpv' id id h gpv and gpv' are related with equality as input relation.
   Proof by coinduction; after unfolding one step (rel_gpv''D) the spmf and generat
   relators are handled by their map/mono rules, and the function-space part by
   map_fun2_id and rel_fun_map_fun1. *)
lemma rel_gpv''_map_gpv'1:
  "rel_gpv'' A C (BNF_Def.Grp UNIV h)¯¯ gpv gpv'  rel_gpv'' A C (=) (map_gpv' id id h gpv) gpv'"
  apply(coinduction arbitrary: gpv gpv')
  apply(drule rel_gpv''D)
  apply(simp add: spmf_rel_map)
  apply(erule rel_spmf_mono)
  apply(simp add: generat.rel_map)
  apply(erule generat.rel_mono_strong; simp?)
  apply(subst map_fun2_id)
  by(auto simp add: rel_fun_comp intro!: rel_fun_map_fun1 elim: rel_fun_mono)

(* If gpv and gpv' are related with eq_on (range h) as input relation, then gpv and
   map_gpv' id id h gpv' are related with the converse graph of h as input relation.
   Proof by coinduction, analogous to rel_gpv''_map_gpv'1 but transforming the right-hand
   gpv (map_fun_id2_in, rel_fun_map_fun2). *)
lemma rel_gpv''_map_gpv'2:
  "rel_gpv'' A C (eq_on (range h)) gpv gpv'  rel_gpv'' A C (BNF_Def.Grp UNIV h)¯¯ gpv (map_gpv' id id h gpv')"
  apply(coinduction arbitrary: gpv gpv')
  apply(drule rel_gpv''D)
  apply(simp add: spmf_rel_map)
  apply(erule rel_spmf_mono_strong)
  apply(simp add: generat.rel_map)
  apply(erule generat.rel_mono_strong; simp?)
  apply(subst map_fun_id2_in)
  apply(rule rel_fun_map_fun2)
  by (auto simp add: rel_fun_comp  elim: rel_fun_mono)

(* Domain of rel_gpv'': this context fixes relations A (results), C (outputs) and
   R (inputs) and shows that every gpv in the domain of rel_gpv'' A C R satisfies
   pred_gpv' for the domains of A and C and the input set {x. Domainp R x}.
   The private lemmas f11, f21, f12, f22 are the individual induction steps. *)
context
  fixes A :: "'a  'd  bool"
    and C :: "'c  'g  bool"
    and R :: "'b  'e  bool"
begin

(* A Pure x element lying in the domain of the generat relator has x in the domain of A. *)
private lemma f11:" Pure x  set_spmf (the_gpv gpv) 
   Domainp (rel_generat A C (rel_fun R (rel_gpv'' A C R))) (Pure x)  Domainp A x"
  by (auto simp add: pred_generat_def elim:bspec dest: generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])

(* An IO out c element related to some ba by the generat relator has out in the domain of C. *)
private lemma f21: "IO out c  set_spmf (the_gpv gpv)  
  rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out c) ba  Domainp C out"
  by (auto simp add: pred_generat_def elim:bspec dest: generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])

(* The domain of rel_gpv'' A C R is closed under continuations on inputs from the domain
   of R: if IO out c is in the support of gpv and gpv is in the domain, so is c input.
   NOTE(review): the third assumption (about results of c input) is not referenced in the
   proof; presumably it is kept so that the lemma matches the shape of the induction step
   in Domainp_rel_gpv''_le where it is applied via simp -- confirm before removing. *)
private lemma f12:
  assumes "IO out c  set_spmf (the_gpv gpv)"
    and "input  responses_ℐ (ℐ_uniform UNIV {x. Domainp R x}) out"
    and "x  results_gpv (ℐ_uniform UNIV {x. Domainp R x}) (c input)"
    and "Domainp (rel_gpv'' A C R) gpv"
  shows "Domainp (rel_gpv'' A C R) (c input)"
proof -
  obtain b1 where o1:"rel_gpv'' A C R gpv b1" using assms(4) by clarsimp
  obtain b2 where o2:"rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out c) b2"
    using assms(1) o1[THEN rel_gpv''D, THEN spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_spmf_def by - (drule (1) bspec, auto)

  have "Ball (generat_conts (IO out c)) (Domainp (rel_fun R (rel_gpv'' A C R)))"
    using o2[THEN generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_generat_def by simp

  with assms(2) show ?thesis 
    apply -
    apply(drule bspec)
     apply simp
    apply clarify
    apply(drule Domainp_rel_fun_le[THEN predicate1D, OF Domainp_iff[THEN iffD2], OF exI])
    by simp  
qed

(* Same statement and proof as f12, but with the unused third assumption phrased for
   outputs (outs_gpv) instead of results; used for the outs_gpv induction below. *)
private lemma f22:
  assumes "IO out' rpv  set_spmf (the_gpv gpv)"
    and "input  responses_ℐ (ℐ_uniform UNIV {x. Domainp R x}) out'"
    and "out  outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) (rpv input)"
    and "Domainp (rel_gpv'' A C R) gpv"
  shows "Domainp (rel_gpv'' A C R) (rpv input)"
proof -
  obtain b1 where o1:"rel_gpv'' A C R gpv b1" using assms(4) by auto
  obtain b2 where o2:"rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out' rpv) b2"
    using assms(1) o1[THEN rel_gpv''D, THEN spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_spmf_def by - (drule (1) bspec, auto)

  have "Ball (generat_conts (IO out' rpv)) (Domainp (rel_fun R (rel_gpv'' A C R)))"
    using o2[THEN generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_generat_def by simp

  with assms(2) show ?thesis 
    apply -
    apply(drule bspec)
     apply simp
    apply clarify
    apply(drule Domainp_rel_fun_le[THEN predicate1D, OF Domainp_iff[THEN iffD2], OF exI])
    by simp 
qed

(* Main result of this context.  Both parts of pred_gpv' are shown by induction on the
   results_gpv / outs_gpv derivation: the base cases use f11 / f21, the step cases are
   discharged by simp with f12 / f22. *)
lemma Domainp_rel_gpv''_le:
  "Domainp (rel_gpv'' A C R)  pred_gpv' (Domainp A) (Domainp C) {x. Domainp R x}"
proof(rule predicate1I pred_gpv'.intros)+
  show "Domainp A x" if "x  results_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv" "Domainp (rel_gpv'' A C R) gpv" for x gpv using that
  proof(induction)
    case (Pure gpv)
    then show ?case 
      by (clarify) (drule rel_gpv''D
          , auto simp add: f11 pred_spmf_def dest: spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])
  qed (simp add: f12) 
  show "Domainp C out" if "out  outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv" "Domainp (rel_gpv'' A C R) gpv" for out gpv using that
  proof( induction)
    case (IO c gpv)
    then show ?case
      by (clarify) (drule rel_gpv''D
          , auto simp add: f21 pred_spmf_def dest!: bspec spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])
  qed (simp add: f22)
qed

end

(* map_gpv' f g h factors into map_gpv f g followed by map_gpv' id id h, i.e. the input
   transformer h can be applied separately.  Proved via map_gpv_conv_map_gpv' and
   map_gpv'_comp. *)
lemma map_gpv'_id12: "map_gpv' f g h gpv = map_gpv' id id h (map_gpv f g gpv)"
  unfolding map_gpv_conv_map_gpv' map_gpv'_comp by simp

(* Reflexivity-style rule: under the three stated comparisons of A, C and R with
   equality, equality on gpvs is contained in rel_gpv'' A C R.
   Proved by rewriting equality as rel_gpv'' on equalities (rel_gpv''_eq, right-to-left)
   and applying monotonicity (rel_gpv''_mono). *)
lemma rel_gpv''_refl: " (=)  A; (=)  C; R  (=)   (=)  rel_gpv'' A C R"
  by(subst rel_gpv''_eq[symmetric])(rule rel_gpv''_mono)


(* Strong monotonicity of rel_gpv''.  The context fixes two triples of relations
   (A, C, R) and (A', C', R'); the private abbreviation and lemmas only serve the proof of
   rel_gpv''_mono_strong. *)
context
  fixes A A' :: "'a  'b  bool"
    and C C' :: "'c  'd  bool"
    and R R' :: "'e  'f  bool"
   
begin

(* foo fx fy gpvx gpvy g1 g2: for all x in fx applied to gpvx (w.r.t. the uniform interface
   on the domain of R') and y in fy applied to gpvy (w.r.t. the one on the range of R'),
   g1 x y entails g2 x y.  fx and fy are instantiated with results_gpv or outs_gpv below. *)
private abbreviation foo where 
  "foo  (λ fx fy gpvx gpvy g1 g2. 
            x y. x  fx (ℐ_uniform UNIV (Collect (Domainp R'))) gpvx 
                  y  fy (ℐ_uniform UNIV (Collect (Rangep R'))) gpvy  g1 x y  g2 x y)"

(* The foo property for results is inherited by continuations applied to R'-related inputs. *)
private lemma f1: "foo results_gpv results_gpv gpv gpv' A A' 
       x  set_spmf (the_gpv gpv)  y  set_spmf (the_gpv gpv') 
       a  generat_conts x  b  generat_conts y   R' a' α  R' β b'  
    foo results_gpv results_gpv (a a') (b b') A A'"
  by (fastforce elim: generat.set_cases intro: results_gpv.IO)

(* The foo property for outputs is inherited by continuations applied to R'-related inputs. *)
private lemma f2: "foo outs_gpv outs_gpv gpv gpv' C C' 
       x  set_spmf (the_gpv gpv)  y  set_spmf (the_gpv gpv') 
       a  generat_conts x  b  generat_conts y  R' a' α  R' β b'  
    foo outs_gpv outs_gpv (a a') (b b') C C'"
  by (fastforce elim: generat.set_cases intro: outs_gpv.Cont)

(* rel_gpv'' A C R gpv gpv' can be changed to rel_gpv'' A' C' R' gpv gpv' provided that
   A entails A' on the results and C entails C' on the outputs of gpv and gpv' (taken
   w.r.t. the uniform interfaces on the domain and range of R'), and the stated comparison
   between R' and R holds.
   Proof by coinduction: the three generat components are treated in turn (results via
   results_gpv.Pure, outputs via outs_gpv.IO, continuations via rel_fun_mono_strong
   together with f1 and f2). *)
lemma rel_gpv''_mono_strong:
  " rel_gpv'' A C R gpv gpv'; 
     x y.  x  results_gpv (ℐ_uniform UNIV {x. Domainp R' x}) gpv; y  results_gpv (ℐ_uniform UNIV {x. Rangep R' x}) gpv'; A x y   A' x y;
     x y.  x  outs_gpv (ℐ_uniform UNIV {x. Domainp R' x}) gpv; y  outs_gpv (ℐ_uniform UNIV {x. Rangep R' x}) gpv'; C x y   C' x y;
     R'  R 
   rel_gpv'' A' C' R' gpv gpv'"
  apply(coinduction arbitrary: gpv gpv')
  apply(drule rel_gpv''D)
  apply(erule rel_spmf_mono_strong)
  apply(erule generat.rel_mono_strong)
    apply(erule generat.set_cases)+
    apply(erule allE, rotate_tac -1)
    apply(erule allE)
    apply(erule impE)
     apply(rule results_gpv.Pure)
     apply simp
    apply(erule impE)
     apply(rule results_gpv.Pure)
     apply simp
    apply simp
   apply(erule generat.set_cases)+
   apply(rotate_tac 1)
   apply(erule allE, rotate_tac -1)
   apply(erule allE)
   apply(erule impE)
    apply(rule outs_gpv.IO)
    apply simp
   apply(erule impE)
    apply(rule outs_gpv.IO)
    apply simp
   apply simp
  apply(erule (1) rel_fun_mono_strong)
  by (fastforce simp add: f1[simplified] f2[simplified])

end

(* Strong reflexivity: gpv is rel_gpv'' A C R-related to itself if A is reflexive on the
   results and C on the outputs of gpv (w.r.t. the uniform interface on the domain of R)
   and the stated comparison between R and equality holds.
   Derived from rel_gpv'' on equalities (rel_gpv''_eq) by rel_gpv''_mono_strong. *)
lemma rel_gpv''_refl_strong:
  assumes "x. x  results_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv  A x x"
    and "x. x  outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv  C x x"
    and "R  (=)"
  shows "rel_gpv'' A C R gpv gpv"
proof -
  have "rel_gpv'' (=) (=) (=) gpv gpv" unfolding rel_gpv''_eq by simp
  then show ?thesis using _ _ assms(3) by(rule rel_gpv''_mono_strong)(auto intro: assms(1-2))
qed

(* Instance of rel_gpv''_refl_strong for the input relation eq_on X: gpv is related to
   itself if A and B are reflexive on its results and outputs w.r.t. ℐ_uniform UNIV X. *)
lemma rel_gpv''_refl_eq_on:
  " x. x  results_gpv (ℐ_uniform UNIV X) gpv  A x x; out. out  outs_gpv (ℐ_uniform UNIV X) gpv  B out out 
   rel_gpv'' A B (eq_on X) gpv gpv"
  by(rule rel_gpv''_refl_strong) (auto elim: eq_onE)

(* Monotonicity of pred_gpv' in its two predicate arguments (same input set R); declared
   as a [mono] rule.  Proved by unfolding pred_gpv'.simps. *)
lemma pred_gpv'_mono' [mono]:
  "pred_gpv' A C R gpv  pred_gpv' A' C' R gpv"
  if "x. A x  A' x" "x. C x  C' x"
  using that unfolding pred_gpv'.simps
  by auto

subsubsection ‹Type judgements›

(* Type judgement for gpvs w.r.t. an interface Γ, defined coinductively with mixfix
   syntax based on the turnstile: GPV gpv is well-typed if for every IO out c in the
   support of gpv, out is in outs_ℐ Γ and, for every input in responses_ℐ Γ out, the
   continuation c input is again well-typed. *)
coinductive WT_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  bool" ("((_)/ ⊢g (_) )" [100, 0] 99)
  for Γ
where
  "(out c. IO out c  set_spmf gpv  out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input ))
   Γ ⊢g GPV gpv "

(* Coinduction rule for WT_gpv phrased with the selector the_gpv instead of the
   constructor GPV, and registered as the default coinduct rule for the predicate.
   For a candidate predicate X holding of gpv, the step must show for every IO out c in
   the support that out is in outs_ℐ Γ (case conclusion "out") and that each continuation
   on an admissible input satisfies X or is already well-typed (case conclusion "cont"). *)
lemma WT_gpv_coinduct [consumes 1, case_names WT_gpv, case_conclusion WT_gpv out cont, coinduct pred: WT_gpv]:
  assumes *: "X gpv"
  and step: "gpv out c.
     X gpv; IO out c  set_spmf (the_gpv gpv) 
     out  outs_ℐ Γ  (input  responses_ℐ Γ out. X (c input)  Γ ⊢g c input )"
  shows "Γ ⊢g gpv "
using * by(coinduct)(auto dest: step simp add: eq_GPV_iff)

(* Unfolding equation for WT_gpv on an explicit GPV constructor (one step of the
   coinductive definition, obtained from WT_gpv.simps). *)
lemma WT_gpv_simps:
  "Γ ⊢g GPV gpv  
   (out c. IO out c  set_spmf gpv  out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input ))"
by(subst WT_gpv.simps) simp

(* Introduction rule for WT_gpv phrased with the_gpv, for an arbitrary gpv (not
   necessarily written with the GPV constructor).  Proved by case analysis on gpv. *)
lemma WT_gpvI:
  "(out c. IO out c  set_spmf (the_gpv gpv)  out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input ))
   Γ ⊢g gpv "
by(cases gpv)(simp add: WT_gpv_simps)

(* Destruction rules for a well-typed gpv:
   WT_gpv_OutD:  the output of any IO element in the support is in outs_ℐ Γ;
   WT_gpv_ContD: the continuation of such an element on any admissible input is well-typed.
   WT_gpvD refers to both. *)
lemma WT_gpvD:
  assumes "Γ ⊢g gpv "
  shows WT_gpv_OutD: "IO out c  set_spmf (the_gpv gpv)  out  outs_ℐ Γ"
  and WT_gpv_ContD: " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ Γ out   Γ ⊢g c input "
using assms by(cases, fastforce)+

(* Monotonicity of well-typedness in the interface: a gpv that is well-typed for ℐ1 is
   well-typed for ℐ2, given the assumption outs comparing the outputs of ℐ1 and ℐ2 and
   the assumption responses comparing, for each output x of ℐ1, the responses of ℐ2
   and ℐ1 to x.  Proved by coinduction using the destruction rules WT_gpvD. *)
lemma WT_gpv_mono:
  assumes WT: "ℐ1 ⊢g gpv "
  and outs: "outs_ℐ ℐ1  outs_ℐ ℐ2"
  and responses: "x. x  outs_ℐ ℐ1  responses_ℐ ℐ2 x  responses_ℐ ℐ1 x"
  shows "ℐ2 ⊢g gpv "
using WT
proof coinduct
  case (WT_gpv gpv out c)
  with outs show ?case by(auto 6 4 dest: responses WT_gpvD)
qed

(* Done x is well-typed for every interface. *)
lemma WT_gpv_Done [iff]: "Γ ⊢g Done x "
by(rule WT_gpvI) simp_all

(* Fail is well-typed for every interface. *)
lemma WT_gpv_Fail [iff]: "Γ ⊢g Fail "
by(rule WT_gpvI) simp_all

(* Introduction rule: Pause out c is well-typed if out is in outs_ℐ Γ and every
   continuation c input on an admissible input is well-typed. *)
lemma WT_gpv_PauseI: 
  " out  outs_ℐ Γ; input. input  responses_ℐ Γ out  Γ ⊢g c input  
    Γ ⊢g Pause out c "
by(rule WT_gpvI) simp_all

(* Characterisation of well-typedness of Pause out c by the condition on out and on the
   continuations; combines WT_gpv_PauseI with the destruction rules WT_gpvD. *)
lemma WT_gpv_Pause [iff]:
  "Γ ⊢g Pause out c   out  outs_ℐ Γ  (input  responses_ℐ Γ out. Γ ⊢g c input )"
by(auto intro: WT_gpv_PauseI dest: WT_gpvD)

(* Introduction rule for bind: if gpv is well-typed and f x is well-typed for every result
   x of gpv, then the bind of gpv and f is well-typed.
   Proof by coinduction with gpv generalised.  An IO element of the bind stems from some
   generat element of gpv: if that element is Pure y, the IO element comes from f y, which
   is well-typed by assumption; if it is IO out' c', the output is the same and the
   continuation is the bind of c' input and f, to which the coinduction hypothesis
   applies. *)
lemma WT_gpv_bindI:
  " Γ ⊢g gpv ; x. x  results_gpv Γ gpv  Γ ⊢g f x  
   Γ ⊢g gpv  f "
proof(coinduction arbitrary: gpv)
  case [rule_format]: (WT_gpv out c gpv)
  from ‹IO out c  _
  obtain generat where generat: "generat  set_spmf (the_gpv gpv)"
    and *: "IO out c  set_spmf (if is_Pure generat then the_gpv (f (result generat))
                                 else return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
    by(auto)
  show ?case
  proof(cases generat)
    case (Pure y)
    with generat have "y  results_gpv Γ gpv" by(auto intro: results_gpv.Pure)
    hence "Γ ⊢g f y " by(rule WT_gpv)
    with * Pure show ?thesis by(auto dest: WT_gpvD) 
  next
    case (IO out' c')
    hence [simp]: "out' = out"
      and c: "input. c input = bind_gpv (c' input) f" using * by simp_all
    from generat IO have **: "IO out c'  set_spmf (the_gpv gpv)" by simp
    with Γ ⊢g gpv  have ?out by(auto dest: WT_gpvD)
    moreover {
      fix input
      assume input: "input  responses_ℐ Γ out"
      with Γ ⊢g gpv  ** have "Γ ⊢g c' input " by(rule WT_gpvD)
      moreover {
        fix y
        assume "y  results_gpv Γ (c' input)"
        with ** input have "y  results_gpv Γ gpv" by(rule results_gpv.IO)
        hence "Γ ⊢g f y " by(rule WT_gpv) }
      moreover note calculation }
    hence ?cont using c by blast
    ultimately show ?thesis ..
  qed
qed

(* Destruction rule for bind (second component): if the bind of gpv and f is well-typed
   and x is a result of gpv, then f x is well-typed.
   Proof by induction on the derivation of x being a result: in the Pure case every IO
   element of f x is also an IO element of the bind; in the IO case the bind's
   continuation on the given input is well-typed (WT_gpv_ContD) and the induction
   hypothesis applies. *)
lemma WT_gpv_bindD2:
  assumes WT: "Γ ⊢g gpv  f "
  and x: "x  results_gpv Γ gpv"
  shows "Γ ⊢g f x "
using x WT
proof induction
  case (Pure gpv)
  show ?case
  proof(rule WT_gpvI)
    fix out c
    assume "IO out c  set_spmf (the_gpv (f x))"
    with Pure have "IO out c  set_spmf (the_gpv (gpv  f))" by(auto intro: rev_bexI)
    with Γ ⊢g gpv  f  show "out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input )"
      by(auto dest: WT_gpvD simp del: set_bind_spmf)
  qed
next
  case (IO out c gpv input)
  from ‹IO out c  _
  have "IO out (λinput. bind_gpv (c input) f)  set_spmf (the_gpv (gpv  f))"
    by(auto intro: rev_bexI)
  with IO.prems have "Γ ⊢g c input  f " using input  _ by(rule WT_gpv_ContD)
  thus ?case by(rule IO.IH)
qed

(* Destruction rule for bind (first component): if the bind of gpv and f is well-typed,
   then gpv is well-typed.  Proof by coinduction: every IO out c of gpv gives rise to an IO
   element of the bind with the same output and continuation "bind (c input) f". *)
lemma WT_gpv_bindD1: "Γ ⊢g gpv  f   Γ ⊢g gpv "
proof(coinduction arbitrary: gpv)
  case (WT_gpv out c gpv)
  from ‹IO out c  _
  have "IO out (λinput. bind_gpv (c input) f)  set_spmf (the_gpv (gpv  f))"
    by(auto intro: rev_bexI)
  with Γ ⊢g gpv  f  show ?case
    by(auto simp del: bind_gpv_sel' dest: WT_gpvD)
qed

(* Simp rule characterising well-typedness of a bind by well-typedness of gpv and of f x
   for all results x of gpv; combines WT_gpv_bindI, WT_gpv_bindD1 and WT_gpv_bindD2. *)
lemma WT_gpv_bind [simp]: "Γ ⊢g gpv  f   Γ ⊢g gpv   (xresults_gpv Γ gpv. Γ ⊢g f x )"
by(blast intro: WT_gpv_bindI dest: WT_gpv_bindD1 WT_gpv_bindD2)

(* Every gpv is well-typed w.r.t. the full interface ℐ_full (by coinduction). *)
lemma WT_gpv_full [simp, intro!]: "ℐ_full ⊢g gpv "
by(coinduction arbitrary: gpv)(auto)

(* A gpv obtained by lifting an spmf is well-typed for every interface. *)
lemma WT_gpv_lift_spmf [simp, intro!]: " ⊢g lift_spmf p "
by(rule WT_gpvI) auto

(* Coinduction rule for WT_gpv "up to bind": in the step, a continuation c input may
   satisfy X, be already well-typed, or be a bind of a well-typed gpv' with a function f
   all of whose values on results of gpv' satisfy X.
   Proof idea: write gpv as the bind of Done x with the constant function returning gpv
   and prove the generalised statement about binds by the plain rule WT_gpv_coinduct.
   An IO element of the bind comes either from f applied to a Pure result (then the step
   assumption applies, re-wrapping continuations with a trivial Done) or from an IO
   element of gpv' (then well-typedness of gpv' applies). *)
lemma WT_gpv_coinduct_bind [consumes 1, case_names WT_gpv, case_conclusion WT_gpv out cont]:
  assumes *: "X gpv"
  and step: "gpv out c.  X gpv; IO out c  set_spmf (the_gpv gpv) 
     out  outs_ℐ   (inputresponses_ℐ  out.
            X (c input) 
             ⊢g c input  
            ((gpv' :: ('b, 'call, 'ret) gpv) f. c input = gpv'  f   ⊢g gpv'   (xresults_gpv  gpv'. X (f x))))"
  shows " ⊢g gpv "
proof -
  fix x
  define gpv' :: "('b, 'call, 'ret) gpv" and f :: "'b  ('a, 'call, 'ret) gpv"
    where "gpv' = Done x" and "f = (λ_. gpv)"
  with * have " ⊢g gpv' " and "x. x  results_gpv  gpv'  X (f x)" by simp_all
  then have " ⊢g gpv'  f "
  proof(coinduction arbitrary: gpv' f rule: WT_gpv_coinduct)
    case [rule_format]: (WT_gpv out c gpv')
    from ‹IO out c  _
    obtain generat where generat: "generat  set_spmf (the_gpv gpv')"
      and *: "IO out c  set_spmf (if is_Pure generat
        then the_gpv (f (result generat))
        else return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
      by(clarsimp)
    show ?case
    proof(cases generat)
      case (Pure x)
      from Pure * have IO: "IO out c  set_spmf (the_gpv (f x))" by simp
      from generat Pure have "x  results_gpv  gpv'" by (simp add: results_gpv.Pure)
      then have "X (f x)" by(rule WT_gpv)
      from step[OF this IO] show ?thesis by(auto 4 4 intro: exI[where x="Done _"])
    next
      case (IO out' c')
      with * have [simp]: "out' = out"
        and c: "c = (λinput. c' input  f)" by simp_all
      from IO generat have IO: "IO out c'  set_spmf (the_gpv gpv')" by simp
      then have "input. input  responses_ℐ  out  results_gpv  (c' input)  results_gpv  gpv'"
        by(auto intro: results_gpv.IO)
      with WT_gpvD[OF  ⊢g gpv'  IO] show ?thesis unfolding c using WT_gpv(2) by blast
    qed
  qed
  then show ?thesis unfolding gpv'_def f_def by simp
qed

(* For an interface satisfying ℐ_trivial, every gpv is well-typed.  Derived from
   WT_gpv_full by monotonicity (WT_gpv_mono) after unfolding ℐ_trivial_def. *)
lemma ℐ_trivial_WT_gpvD [simp]: "ℐ_trivial    ⊢g gpv "
using WT_gpv_full by(rule WT_gpv_mono)(simp_all add: ℐ_trivial_def)

(* Converse direction: if every gpv of the given type is well-typed for the interface,
   then the interface satisfies ℐ_trivial.  The proof instantiates the assumption with
   Pause x (λ_. Fail), whose well-typedness forces x to be in outs_ℐ of the interface. *)
lemma ℐ_trivial_WT_gpvI: 
  assumes "gpv :: ('a, 'out, 'in) gpv.  ⊢g gpv "
  shows "ℐ_trivial "
proof
  fix x
  have " ⊢g Pause x (λ_. Fail :: ('a, 'out, 'in) gpv) " by(rule assms)
  thus "x  outs_ℐ " by(simp)
qed

(* Monotonicity of well-typedness w.r.t. the ordering on interfaces: the judgement
   transfers from one interface to ℐ' under the stated ordering assumption.
   Reduces to WT_gpv_mono via outs_ℐ_mono and responses_ℐ_mono. *)
lemma WT_gpv_ℐ_mono: "  ⊢g gpv ;   ℐ'   ℐ' ⊢g gpv "
  by(erule WT_gpv_mono; rule outs_ℐ_mono responses_ℐ_mono)

(* Results are monotone in the interface for well-typed gpvs: given the ordering
   assumption le between the two interfaces and well-typedness of gpv w.r.t. ℐ', the
   result sets of gpv w.r.t. the two interfaces are related as stated.
   Proved by induction on the results_gpv derivation, using responses_ℐ_mono (with le)
   and the destruction rules WT_gpvD. *)
lemma results_gpv_mono:
  assumes le: "ℐ'  " and WT: "ℐ' ⊢g gpv "
  shows "results_gpv  gpv  results_gpv ℐ' gpv"
proof(rule subsetI, goal_cases)
  case (1 x)
  show ?case using 1 WT by(induction)(auto 4 3 intro: results_gpv.intros responses_ℐ_mono[OF le, THEN subsetD] intro: WT_gpvD)
qed

(* The outputs of a well-typed gpv are outputs of the interface.  Proved by induction on
   the outs_gpv derivation with WT_gpv_OutD and WT_gpv_ContD. *)
lemma WT_gpv_outs_gpv:
  assumes " ⊢g gpv "
  shows "outs_gpv  gpv  outs_ℐ "
proof
  show "x  outs_ℐ " if "x  outs_gpv  gpv" for x using that assms
    by(induction)(blast intro: WT_gpv_OutD WT_gpv_ContD)+
qed

(* map_gpv' f g h gpv is well-typed for an interface if gpv is well-typed for that
   interface transformed by map_ℐ g h.  Proved by coinduction. *)
lemma WT_gpv_map_gpv': " ⊢g map_gpv' f g h gpv " if "map_ℐ g h  ⊢g gpv "
  using that by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)

(* Special case of WT_gpv_map_gpv' for map_gpv, i.e. with id as input transformer
   (map_gpv_conv_map_gpv'). *)
lemma WT_gpv_map_gpv: " ⊢g map_gpv f g gpv " if "map_ℐ g id  ⊢g gpv "
  unfolding map_gpv_conv_map_gpv' using that by(rule WT_gpv_map_gpv')

(* Simp rule: the results of map_gpv' f g h gpv w.r.t. an interface are the image under f
   of the results of gpv w.r.t. that interface transformed by map_ℐ g h.
   Both inclusions are shown by induction on the results_gpv derivation. *)
lemma results_gpv_map_gpv' [simp]:
  "results_gpv  (map_gpv' f g h gpv) = f ` (results_gpv (map_ℐ g h ) gpv)"
proof(intro Set.set_eqI iffI; (elim imageE; hypsubst)?)
  show "x  f ` results_gpv (map_ℐ g h ) gpv" if "x  results_gpv  (map_gpv' f g h gpv)" for x using that
    by(induction gpv'"map_gpv' f g h gpv" arbitrary: gpv)(fastforce intro: results_gpv.intros rev_image_eqI)+
  show "f x  results_gpv  (map_gpv' f g h gpv)" if "x  results_gpv (map_ℐ g h ) gpv" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed

(* Parametricity of WT_gpv: for a bi-unique output relation C, related interfaces
   (rel_ℐ C R) and related gpvs (rel_gpv'' A C R) yield equal well-typedness judgements.
   The first show proves one direction by coinduction, using transfer_prover to relate
   the supports of the two gpvs, the output membership tests and the response sets.
   The second show obtains the other direction by instantiating the first one with the
   converse relations (rel_gpv''_conversep). *)
lemma WT_gpv_parametric': includes lifting_syntax shows
  "bi_unique C  (rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) WT_gpv WT_gpv"
proof(rule rel_funI iffI)+
  note [transfer_rule] = the_gpv_parametric'
  show *: " ⊢g gpv " if [transfer_rule]: "rel_ℐ C R  ℐ'" "bi_unique C" 
    and *: "ℐ' ⊢g gpv' " "rel_gpv'' A C R gpv gpv'" for  ℐ' gpv gpv' A C R
    using *
  proof(coinduction arbitrary: gpv gpv')
    case (WT_gpv out c gpv gpv')
    note [transfer_rule] = WT_gpv(2)
    have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf (the_gpv gpv)) (set_spmf (the_gpv gpv'))" 
      by transfer_prover
    from rel_setD1[OF this WT_gpv(3)] obtain out' c'
      where [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
        and out': "IO out' c'  set_spmf (the_gpv gpv')"
      by(auto elim: generat.rel_cases)
    have "out  outs_ℐ   out'  outs_ℐ ℐ'" by transfer_prover
    with WT_gpvD(1)[OF WT_gpv(1) out'] have ?out by simp
    moreover have ?cont
    proof(standard; goal_cases cont)
      case (cont input)
      have "rel_set R (responses_ℐ  out) (responses_ℐ ℐ' out')" by transfer_prover
      from rel_setD1[OF this cont] obtain input' where [transfer_rule]: "R input input'"
        and input': "input'  responses_ℐ ℐ' out'" by blast
      have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
      with WT_gpvD(2)[OF WT_gpv(1) out' input'] show ?case by auto
    qed
    ultimately show ?case ..
  qed

  show "ℐ' ⊢g gpv' " if "rel_ℐ C R  ℐ'" "bi_unique C" " ⊢g gpv " "rel_gpv'' A C R gpv gpv'" 
    for  ℐ' gpv gpv'
    using *[of "conversep C" "conversep R" ℐ'  gpv "conversep A" gpv'] that
    by(simp add: rel_gpv''_conversep)
qed

(* Simp rule: map_gpv f id (mapping results only) does not affect well-typedness.
   Obtained from WT_gpv_parametric' instantiated with graph relations of id and f and
   equality on inputs. *)
lemma WT_gpv_map_gpv_id [simp]: " ⊢g map_gpv f id gpv    ⊢g gpv "
  using WT_gpv_parametric'[of "BNF_Def.Grp UNIV id" "(=)" "BNF_Def.Grp UNIV f", folded rel_gpv_conv_rel_gpv'']
  unfolding gpv.rel_Grp unfolding eq_alt[symmetric] relator_eq
  by(auto simp add: rel_fun_def Grp_def bi_unique_eq)

(* Converse of WT_gpv_outs_gpv: a gpv all of whose outputs are outputs of the interface
   is well-typed.  Proved by coinduction using the introduction rules of outs_gpv. *)
lemma WT_gpv_outs_gpvI:
  assumes "outs_gpv  gpv  outs_ℐ "
  shows " ⊢g gpv "
  using assms by(coinduction arbitrary: gpv)(auto intro: outs_gpv.intros)

(* Well-typedness expressed via outputs: combines WT_gpv_outs_gpvI and WT_gpv_outs_gpv. *)
lemma WT_gpv_iff_outs_gpv:
  " ⊢g gpv   outs_gpv  gpv  outs_ℐ "
  by(blast intro: WT_gpv_outs_gpvI dest: WT_gpv_outs_gpv)

subsection ‹Sub-gpvs›

(* Sub-gpvs of a gpv relative to an interface.
   The predicate sub_gpvsp is defined inductively by two rules:
     base: x is the continuation c input of an IO out c in the support of the_gpv gpv,
           for an input among the interface's responses to out;
     cont: x is (inductively) a sub-gpv of such a continuation c input.
   sub_gpvs is the corresponding set. *)
context begin
qualified inductive sub_gpvsp :: "('out, 'in) ('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv  bool"
  for  x
where
  base:
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out; x = c input  
   sub_gpvsp  x gpv"
| cont: 
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out; sub_gpvsp  x (c input) 
   sub_gpvsp  x gpv"

(* Variant of rule base with the equation x = c input already substituted. *)
qualified lemma sub_gpvsp_base:
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out  
   sub_gpvsp  (c input) gpv"
by(rule base) simp_all

definition sub_gpvs :: "('out, 'in)  ('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv set"
where "sub_gpvs  gpv  {x. sub_gpvsp  x gpv}"

(* Conversion between predicate and set formulation, used by the to_set attribute below. *)
lemma sub_gpvsp_sub_gpvs_eq [pred_set_conv]: "sub_gpvsp  x gpv  x  sub_gpvs  gpv"
by(simp add: sub_gpvs_def)

(* Set-based versions of the rules, stored under the mandatory name prefix "sub_gpvs".
   NOTE(review): the case names given for induct and cases below are Pure and IO, whereas
   the introduction rules of sub_gpvsp are called base and cont; this looks like it was
   carried over from the analogous setup of another set.  Proofs elsewhere may refer to
   these case names, so confirm all uses before renaming them. *)
context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "sub_gpvs")

lemmas intros [intro?] = sub_gpvsp.intros[to_set]
  and base = sub_gpvsp_base[to_set]
  and cont = cont[to_set]
  and induct [consumes 1, case_names Pure IO, induct set: sub_gpvs] = sub_gpvsp.induct[to_set]
  and cases [consumes 1, case_names Pure IO, cases set: sub_gpvs] = sub_gpvsp.cases[to_set]
  and simps = sub_gpvsp.simps[to_set]
end
end

(* Sub-gpvs of a well-typed gpv are well-typed.  Proved by induction on the sub_gpvs
   derivation using the destruction rules WT_gpvD. *)
lemma WT_sub_gpvsD:
  assumes " ⊢g gpv " and "gpv'  sub_gpvs  gpv"
  shows " ⊢g gpv' "
using assms(2,1) by(induction)(auto dest: WT_gpvD)

(* Introduction rule via sub-gpvs: a gpv is well-typed if the outputs of all IO elements
   in its support are in outs_ℐ Γ and all its sub-gpvs are well-typed.  Only the direct
   continuations (sub_gpvs.base) are needed in the proof. *)
lemma WT_sub_gpvsI:
  " out c. IO out c  set_spmf (the_gpv gpv)  out  outs_ℐ Γ; 
     gpv'. gpv'  sub_gpvs Γ gpv  Γ ⊢g gpv'  
   Γ ⊢g gpv "
by(rule WT_gpvI)(auto intro: sub_gpvs.base)

subsection ‹Losslessness›

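text ‹A gpv is lossless iff we are guaranteed to get a result after a finite number of interactions
  that respect the interface. It is colossless if the interactions may go on forever, but no
  individual step fails or fails to terminate, i.e., every sub-probability distribution reached by
  interactions that respect the interface is lossless.›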

text ‹We define both notions of losslessness simultaneously by mimicking what the (co)inductive
  package would do internally. Thus, we get a single constant which is parametrised by the choice of
  the fixpoint (least for lossless, greatest for colossless), so that for non-recursive gpvs we can
  state and prove both versions of losslessness in one go.›

context
  fixes co :: bool and  :: "('out, 'in) ℐ"
  and F :: "(('a, 'out, 'in) gpv  bool)  (('a, 'out, 'in) gpv  bool)"
  and co' :: bool
  defines "F  λgen_lossless_gpv gpv. pa. gpv = GPV pa  
     lossless_spmf pa  (out c input. IO out c  set_spmf pa  input  responses_ℐ  out  gen_lossless_gpv (c input))"
  and "co'  co" ― ‹We use a copy of @{term co} such that we can do case distinctions on @{term co'} without
    the simplifier rewriting the @{term co} in the local abbreviations for the constants.›
begin

lemma gen_lossless_gpv_mono: "mono F"
unfolding F_def
apply(rule monoI le_funI le_boolI')+
apply(tactic ‹REPEAT (resolve_tac @{context} (Inductive.get_monos @{context}) 1))
apply(erule le_funE)
apply(erule le_boolD)
done

definition gen_lossless_gpv :: "('a, 'out, 'in) gpv  bool"
where "gen_lossless_gpv = (if co' then gfp else lfp) F"

lemma gen_lossless_gpv_unfold: "gen_lossless_gpv = F gen_lossless_gpv"
by(simp add: gen_lossless_gpv_def gfp_unfold[OF gen_lossless_gpv_mono, symmetric] lfp_unfold[OF gen_lossless_gpv_mono, symmetric])

lemma gen_lossless_gpv_True: "co' = True  gen_lossless_gpv  gfp F"
  and gen_lossless_gpv_False: "co' = False  gen_lossless_gpv  lfp F"
by(simp_all add: gen_lossless_gpv_def)

lemma gen_lossless_gpv_cases [elim?, cases pred]:
  assumes "gen_lossless_gpv gpv"
  obtains (gen_lossless_gpv) p where "gpv = GPV p" "lossless_spmf p" 
    "out c input. IO out c  set_spmf p; input  responses_ℐ  out  gen_lossless_gpv (c input)"
proof -
  from assms show ?thesis
    by(rewrite in asm gen_lossless_gpv_unfold)(auto simp add: F_def intro: that)
qed

(* Destruction rules phrased with the selector the_gpv: the first step is a
   lossless spmf, and the predicate carries over to continuations on
   admissible inputs.  Both parts are also available under their own names. *)
lemma gen_lossless_gpvD:
  assumes "gen_lossless_gpv gpv"
  shows gen_lossless_gpv_lossless_spmfD: "lossless_spmf (the_gpv gpv)"
  and gen_lossless_gpv_continuationD:
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out   gen_lossless_gpv (c input)"
using assms by(auto elim: gen_lossless_gpv_cases)

(* Introduction rule for a gpv given in constructor form GPV p; one unfolding
   of the fixpoint equation. *)
lemma gen_lossless_gpv_intros:
  " lossless_spmf p;
     out c input. IO out c  set_spmf p; input  responses_ℐ  out   gen_lossless_gpv (c input) 
   gen_lossless_gpv (GPV p)"
by(rewrite gen_lossless_gpv_unfold)(simp add: F_def)

(* Introduction rule phrased with the_gpv, so that it applies to an arbitrary
   gpv rather than only to terms of the form GPV p. *)
lemma gen_lossless_gpvI [intro?]:
  " lossless_spmf (the_gpv gpv);
     out c input.  IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out 
      gen_lossless_gpv (c input) 
   gen_lossless_gpv gpv"
by(cases gpv)(auto intro: gen_lossless_gpv_intros)

(* The fixpoint equation spelled out with F unfolded.  Not declared [simp];
   it can be used to rewrite gen_lossless_gpv by hand. *)
lemma gen_lossless_gpv_simps:
  "gen_lossless_gpv gpv 
   (p. gpv = GPV p  lossless_spmf p  (out c input.
        IO out c  set_spmf p  input  responses_ℐ  out  gen_lossless_gpv (c input)))"
by(rewrite gen_lossless_gpv_unfold)(simp add: F_def)

(* A gpv that immediately returns a result satisfies the predicate. *)
lemma gen_lossless_gpv_Done [iff]: "gen_lossless_gpv (Done x)"
by(rule gen_lossless_gpvI) auto

(* The failing gpv never satisfies the predicate. *)
lemma gen_lossless_gpv_Fail [iff]: "¬ gen_lossless_gpv Fail"
by(auto dest: gen_lossless_gpvD)

(* Simp rule for Pause: the predicate on Pause out c is reduced to the
   predicate on c input for the inputs in the responses of out. *)
lemma gen_lossless_gpv_Pause [simp]:
  "gen_lossless_gpv (Pause out c)  (input  responses_ℐ  out. gen_lossless_gpv (c input))"
by(auto dest: gen_lossless_gpvD intro: gen_lossless_gpvI)

(* For a lifted spmf the predicate reduces to losslessness of the spmf. *)
lemma gen_lossless_gpv_lift_spmf [iff]: "gen_lossless_gpv (lift_spmf p)  lossless_spmf p"
by(auto dest: gen_lossless_gpvD intro: gen_lossless_gpvI)

end

(* For an assertion the predicate reduces to the asserted condition b;
   proved by case distinction on b. *)
lemma gen_lossless_gpv_assert_gpv [iff]: "gen_lossless_gpv co  (assert_gpv b)  b"
by(cases b) simp_all

(* Inductive reading of the predicate: gen_lossless_gpv with co = False,
   i.e. the least fixpoint of F. *)
abbreviation lossless_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  bool"
where "lossless_gpv  gen_lossless_gpv False"

(* Coinductive reading of the predicate: gen_lossless_gpv with co = True,
   i.e. the greatest fixpoint of F. *)
abbreviation colossless_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  bool"
where "colossless_gpv  gen_lossless_gpv True"

(* Induction principle for lossless_gpv, derived from the least-fixpoint
   induction rule def_lfp_induct.  In the step case one may assume that p is
   lossless and that both lossless_gpv and the property P hold for all
   continuations on admissible inputs. *)
lemma lossless_gpv_induct [consumes 1, case_names lossless_gpv, induct pred]:
  assumes *: "lossless_gpv  gpv"
  and step: "p.  lossless_spmf p;
     out c input. IO out c  set_spmf p; input  responses_ℐ  out  lossless_gpv  (c input);
     out c input. IO out c  set_spmf p; input  responses_ℐ  out  P (c input) 
      P (GPV p)"
  shows "P gpv"
proof -
  have "lossless_gpv   P"
    by(rule def_lfp_induct[OF gen_lossless_gpv_False gen_lossless_gpv_mono])(auto intro!: le_funI step)
  then show ?thesis using * by auto
qed

(* Coinduction principle for colossless_gpv, derived from def_coinduct.  The
   candidate X must guarantee a lossless first step, and every continuation on
   an admissible input must again be in X or already be colossless.  The two
   conclusions of the step are named lossless_spmf and continuation. *)
lemma colossless_gpv_coinduct 
  [consumes 1, case_names colossless_gpv, case_conclusion colossless_gpv lossless_spmf continuation, coinduct pred]:
  assumes *: "X gpv"
  and step: "gpv. X gpv  lossless_spmf (the_gpv gpv)  (out c input. 
       IO out c  set_spmf (the_gpv gpv)  input  responses_ℐ  out  X (c input)  colossless_gpv  (c input))"
  shows "colossless_gpv  gpv"
proof -
  have "X  colossless_gpv "
    by(rule def_coinduct[OF gen_lossless_gpv_True gen_lossless_gpv_mono])
      (auto 4 4 intro!: le_funI dest!: step intro: exI[where x="the_gpv _"])
  then show ?thesis using * by auto
qed

(* Introduction and destruction rules instantiated for lossless_gpv
   (co = False). *)
lemmas lossless_gpvI = gen_lossless_gpvI[where co=False]
  and lossless_gpvD = gen_lossless_gpvD[where co=False]
  and lossless_gpv_lossless_spmfD = gen_lossless_gpv_lossless_spmfD[where co=False]
  and lossless_gpv_continuationD = gen_lossless_gpv_continuationD[where co=False]

(* Introduction and destruction rules instantiated for colossless_gpv
   (co = True). *)
lemmas colossless_gpvI = gen_lossless_gpvI[where co=True]
  and colossless_gpvD = gen_lossless_gpvD[where co=True]
  and colossless_gpv_lossless_spmfD = gen_lossless_gpv_lossless_spmfD[where co=True]
  and colossless_gpv_continuationD = gen_lossless_gpv_continuationD[where co=True]

(* Sufficient condition for a bind: if gpv satisfies the predicate and f x
   satisfies it for every x in results_gpv of gpv, then so does the bind of
   gpv and f.  The proof splits on co: the case co = False goes by induction
   on lossless_gpv, the case co = True by coinduction with
   colossless_gpv_coinduct, generalising over gpv. *)
lemma gen_lossless_bind_gpvI:
  assumes "gen_lossless_gpv co  gpv" "x. x  results_gpv  gpv  gen_lossless_gpv co  (f x)"
  shows "gen_lossless_gpv co  (gpv  f)"
proof(cases co)
  case False
  hence eq: "co = False" by simp
  show ?thesis using assms unfolding eq
  proof(induction)
    case (lossless_gpv p)
    (* Immediate results of GPV p are results of the whole gpv, so the
       assumption on f applies to them. *)
    { fix x
      assume "Pure x  set_spmf p"
      hence "x  results_gpv  (GPV p)" by simp
      hence "lossless_gpv  (f x)" by(rule lossless_gpv.prems) }
    with ‹lossless_spmf p show ?case unfolding GPV_bind
      apply(intro gen_lossless_gpv_intros)
       apply(fastforce dest: lossless_gpvD split: generat.split)
      apply(clarsimp; split generat.split_asm)
      apply(auto dest: lossless_gpvD intro!: lossless_gpv)
      done
  qed
next
  case True
  hence eq: "co = True" by simp
  show ?thesis using assms unfolding eq
  proof(coinduction arbitrary: gpv rule: colossless_gpv_coinduct)
    case * [rule_format]: (colossless_gpv gpv)
    from *(1) have ?lossless_spmf 
      by(auto 4 3 dest: colossless_gpv_lossless_spmfD elim!: is_PureE intro: *(2)[THEN colossless_gpv_lossless_spmfD] results_gpv.Pure)
    moreover have ?continuation
    proof(intro strip)
      fix out c input
      assume IO: "IO out c  set_spmf (the_gpv (gpv  f))"
        and input: "input  responses_ℐ  out"
      (* An IO step of the bind stems from some generat of gpv: either a
         Pure result whose continuation f produced the step, or an IO step
         of gpv itself. *)
      from IO obtain generat where generat: "generat  set_spmf (the_gpv gpv)"
        and IO: "IO out c  set_spmf (if is_Pure generat then the_gpv (f (result generat))
                 else return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
        by(auto)
      show "(gpv. c input = gpv  f  colossless_gpv  gpv  (x. x  results_gpv  gpv  colossless_gpv  (f x))) 
        colossless_gpv  (c input)"
      proof(cases generat)
        case (Pure x)
        hence "x  results_gpv  gpv" using generat by(auto intro: results_gpv.Pure)
        from *(2)[OF this] have "colossless_gpv  (c input)"
          using IO Pure input by(auto intro: colossless_gpv_continuationD)
        thus ?thesis ..
      next
        case **: (IO out' c')
        with input generat IO have "colossless_gpv  (f x)" if "x  results_gpv  (c' input)" for x
          using that by(auto intro: * results_gpv.IO)
        then show ?thesis using IO input ** *(1) generat by(auto dest: colossless_gpv_continuationD)
      qed
    qed
    ultimately show ?case ..
  qed
qed

(* gen_lossless_bind_gpvI instantiated for the two values of co. *)
lemmas lossless_bind_gpvI = gen_lossless_bind_gpvI[where co=False]
  and colossless_bind_gpvI = gen_lossless_bind_gpvI[where co=True]

(* If a bind satisfies the predicate, then so does its first argument.
   Split on co: induction (with the bind abstracted as gpv') for co = False,
   coinduction for co = True. *)
lemma gen_lossless_bind_gpvD1: 
  assumes "gen_lossless_gpv co  (gpv  f)"
  shows "gen_lossless_gpv co  gpv"
proof(cases co)
  case False
  hence eq: "co = False" by simp
  show ?thesis using assms unfolding eq
  proof(induction gpv'"gpv  f" arbitrary: gpv)
    case (lossless_gpv p)
    obtain p' where gpv: "gpv = GPV p'" by(cases gpv)
    from lossless_gpv.hyps gpv have "lossless_spmf p'" by(simp add: GPV_bind)
    then show ?case unfolding gpv
    proof(rule gen_lossless_gpv_intros)
      fix out c input
      assume "IO out c  set_spmf p'" "input  responses_ℐ  out"
      (* An IO step of p' shows up in p with its continuation bound to f. *)
      hence "IO out (λinput. c input  f)  set_spmf p" using lossless_gpv.hyps gpv
        by(auto simp add: GPV_bind intro: rev_bexI)
      thus "lossless_gpv  (c input)" using input  _ by(rule lossless_gpv.hyps) simp
    qed
  qed
next
  case True
  hence eq: "co = True" by simp
  show ?thesis using assms unfolding eq
    by(coinduction arbitrary: gpv)(auto 4 3 intro: rev_bexI elim!: colossless_gpv_continuationD dest: colossless_gpv_lossless_spmfD)
qed

(* gen_lossless_bind_gpvD1 instantiated for the two values of co. *)
lemmas lossless_bind_gpvD1 = gen_lossless_bind_gpvD1[where co=False]
  and colossless_bind_gpvD1 = gen_lossless_bind_gpvD1[where co=True]

(* If a bind satisfies the predicate, then f x satisfies it for every result
   x of the first argument.  The proof is by induction on the derivation of
   x being a result of gpv (assumption 2 is fed in first). *)
lemma gen_lossless_bind_gpvD2:
  assumes "gen_lossless_gpv co  (gpv  f)"
  and "x  results_gpv  gpv"
  shows "gen_lossless_gpv co  (f x)"
using assms(2,1)
proof(induction)
  case (Pure gpv)
  thus ?case
    by -(rule gen_lossless_gpvI, auto 4 4 dest: gen_lossless_gpvD intro: rev_bexI)
qed(auto 4 4 dest: gen_lossless_gpvD intro: rev_bexI)

(* gen_lossless_bind_gpvD2 instantiated for the two values of co. *)
lemmas lossless_bind_gpvD2 = gen_lossless_bind_gpvD2[where co=False]
  and colossless_bind_gpvD2 = gen_lossless_bind_gpvD2[where co=True]

(* Simp rule for bind, combining the introduction rule
   gen_lossless_bind_gpvI with the two destruction rules D1 and D2. *)
lemma gen_lossless_bind_gpv [simp]:
  "gen_lossless_gpv co  (gpv  f)  gen_lossless_gpv co  gpv  (xresults_gpv  gpv. gen_lossless_gpv co  (f x))"
by(blast intro: gen_lossless_bind_gpvI dest: gen_lossless_bind_gpvD1 gen_lossless_bind_gpvD2)

(* gen_lossless_bind_gpv instantiated for the two values of co. *)
lemmas lossless_bind_gpv = gen_lossless_bind_gpv[where co=False]
  and colossless_bind_gpv = gen_lossless_bind_gpv[where co=True]

context includes lifting_syntax begin

(* Transfer of lossless_gpv from left to right along rel_gpv'' A C R, given
   that the two interfaces are related by rel_ℐ C R.  Induction on
   lossless_gpv of the left gpv; the relational facts needed in each step are
   produced by transfer_prover from the locally declared transfer rules. *)
lemma rel_gpv''_lossless_gpvD1:
  assumes rel: "rel_gpv'' A C R gpv gpv'"
  and gpv: "lossless_gpv  gpv"
  and [transfer_rule]: "rel_ℐ C R  ℐ'"
  shows "lossless_gpv ℐ' gpv'"
using gpv rel
proof(induction arbitrary: gpv')
  case (lossless_gpv p)
  from lossless_gpv.prems obtain q where q: "gpv' = GPV q"
    and [transfer_rule]: "rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) p q"
    by(cases gpv') auto
  show ?case
  proof(rule lossless_gpvI)
    have "lossless_spmf p = lossless_spmf q" by transfer_prover
    with lossless_gpv.hyps(1) q show "lossless_spmf (the_gpv gpv')" by simp

    fix out' c' input'
    assume IO': "IO out' c'  set_spmf (the_gpv gpv')"
      and input': "input'  responses_ℐ ℐ' out'"
    have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf p) (set_spmf q)"
      by transfer_prover
    (* Pick a related IO step on the left for the given step on the right. *)
    with IO' q obtain out c where IO: "IO out c  set_spmf p"
      and [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
      by(auto dest!: rel_setD2 elim: generat.rel_cases)
    have "rel_set R (responses_ℐ  out) (responses_ℐ ℐ' out')" by transfer_prover
    moreover
    (* Likewise pick a related admissible input on the left. *)
    from this[THEN rel_setD2, OF input'] obtain input
      where [transfer_rule]: "R input input'" and input: "input  responses_ℐ  out" by blast
    have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
    ultimately show "lossless_gpv ℐ' (c' input')" using input IO by(auto intro: lossless_gpv.IH)
  qed
qed

(* Right-to-left direction of rel_gpv''_lossless_gpvD1, obtained by applying
   that lemma to the converse relations. *)
lemma rel_gpv''_lossless_gpvD2:
  " rel_gpv'' A C R gpv gpv'; lossless_gpv ℐ' gpv'; rel_ℐ C R  ℐ' 
   lossless_gpv  gpv"
using rel_gpv''_lossless_gpvD1[of "A¯¯" "C¯¯" "R¯¯" gpv' gpv ℐ' ]
by(simp add: rel_gpv''_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Special case of rel_gpv''_lossless_gpvD1 for rel_gpv, i.e. with equality
   as the relation on inputs. *)
lemma rel_gpv_lossless_gpvD1:
  " rel_gpv A C gpv gpv'; lossless_gpv  gpv; rel_ℐ C (=)  ℐ'   lossless_gpv ℐ' gpv'"
using rel_gpv''_lossless_gpvD1[of A C "(=)" gpv gpv'  ℐ'] by(simp add: rel_gpv_conv_rel_gpv'')

(* Right-to-left direction of rel_gpv_lossless_gpvD1, via the converse
   relations. *)
lemma rel_gpv_lossless_gpvD2:
  " rel_gpv A C gpv gpv'; lossless_gpv ℐ' gpv'; rel_ℐ C (=)  ℐ' 
   lossless_gpv  gpv"
using rel_gpv_lossless_gpvD1[of "A¯¯" "C¯¯" gpv' gpv ℐ' ]
by(simp add: gpv.rel_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Coinductive counterpart of rel_gpv''_lossless_gpvD1: colossless_gpv is
   transferred from left to right along rel_gpv'' A C R.  Proved by
   coinduction, generalising over both gpvs; in the continuation case the
   first disjunct of the coinduction hypothesis is always chosen. *)
lemma rel_gpv''_colossless_gpvD1:
  assumes rel: "rel_gpv'' A C R gpv gpv'"
  and gpv: "colossless_gpv  gpv"
  and [transfer_rule]: "rel_ℐ C R  ℐ'"
  shows "colossless_gpv ℐ' gpv'"
using gpv rel
proof(coinduction arbitrary: gpv gpv')
  case (colossless_gpv gpv gpv')
  note [transfer_rule] = ‹rel_gpv'' A C R gpv gpv' the_gpv_parametric'
    and co = ‹colossless_gpv  gpv
  have "lossless_spmf (the_gpv gpv) = lossless_spmf (the_gpv gpv')" by transfer_prover
  with co have "?lossless_spmf" by(auto dest: colossless_gpv_lossless_spmfD)
  moreover have "?continuation"
  proof(intro strip disjI1)
    fix out' c' input'
    assume IO': "IO out' c'  set_spmf (the_gpv gpv')"
      and input': "input'  responses_ℐ ℐ' out'"
    have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf (the_gpv gpv)) (set_spmf (the_gpv gpv'))"
      by transfer_prover
    (* Pick a related IO step on the left for the given step on the right. *)
    with IO' obtain out c where IO: "IO out c  set_spmf (the_gpv gpv)"
      and [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
      by(auto dest!: rel_setD2 elim: generat.rel_cases)
    have "rel_set R (responses_ℐ  out) (responses_ℐ ℐ' out')" by transfer_prover
    moreover 
    (* Likewise pick a related admissible input on the left. *)
    from this[THEN rel_setD2, OF input'] obtain input
      where [transfer_rule]: "R input input'" and input: "input  responses_ℐ  out" by blast
    have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
    ultimately show "gpv gpv'. c' input' = gpv'  colossless_gpv  gpv  rel_gpv'' A C R gpv gpv'"
      using input IO co by(auto dest: colossless_gpv_continuationD)
  qed
  ultimately show ?case ..
qed

(* Right-to-left direction of rel_gpv''_colossless_gpvD1, via the converse
   relations. *)
lemma rel_gpv''_colossless_gpvD2:
  " rel_gpv'' A C R gpv gpv'; colossless_gpv ℐ' gpv'; rel_ℐ C R  ℐ' 
   colossless_gpv  gpv"
using rel_gpv''_colossless_gpvD1[of "A¯¯" "C¯¯" "R¯¯" gpv' gpv ℐ' ]
by(simp add: rel_gpv''_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Special case of rel_gpv''_colossless_gpvD1 for rel_gpv, i.e. with equality
   as the relation on inputs. *)
lemma rel_gpv_colossless_gpvD1:
  " rel_gpv A C gpv gpv'; colossless_gpv  gpv; rel_ℐ C (=)  ℐ'   colossless_gpv ℐ' gpv'"
using rel_gpv''_colossless_gpvD1[of A C "(=)" gpv gpv'  ℐ'] by(simp add: rel_gpv_conv_rel_gpv'')

(* Right-to-left direction of rel_gpv_colossless_gpvD1, via the converse
   relations. *)
lemma rel_gpv_colossless_gpvD2:
  " rel_gpv A C gpv gpv'; colossless_gpv ℐ' gpv'; rel_ℐ C (=)  ℐ' 
   colossless_gpv  gpv"
using rel_gpv_colossless_gpvD1[of "A¯¯" "C¯¯" gpv' gpv ℐ' ]
by(simp add: gpv.rel_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Parametricity of gen_lossless_gpv with respect to rel_gpv''.  After fixing
   the Boolean flag, the statement follows by cases on it from the four
   transfer lemmas for lossless_gpv and colossless_gpv. *)
lemma gen_lossless_gpv_parametric':
  "((=) ===> rel_ℐ C R ===> rel_gpv'' A C R ===> (=))
   gen_lossless_gpv gen_lossless_gpv"
proof(rule rel_funI; hypsubst)
  show "(rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) (gen_lossless_gpv b) (gen_lossless_gpv b)" for b
    by(cases b)(auto intro!: rel_funI dest: rel_gpv''_colossless_gpvD1 rel_gpv''_colossless_gpvD2 rel_gpv''_lossless_gpvD1 rel_gpv''_lossless_gpvD2)
qed

(* Parametricity of gen_lossless_gpv with respect to rel_gpv; registered as a
   transfer rule.  Same proof pattern as gen_lossless_gpv_parametric', using
   the rel_gpv versions of the transfer lemmas. *)
lemma gen_lossless_gpv_parametric [transfer_rule]:
  "((=) ===> rel_ℐ C (=) ===> rel_gpv A C ===> (=))
   gen_lossless_gpv gen_lossless_gpv"
proof(rule rel_funI; hypsubst)
  show "(rel_ℐ C (=) ===> rel_gpv A C ===> (=)) (gen_lossless_gpv b) (gen_lossless_gpv b)" for b
    by(cases b)(auto intro!: rel_funI dest: rel_gpv_colossless_gpvD1 rel_gpv_colossless_gpvD2 rel_gpv_lossless_gpvD1 rel_gpv_lossless_gpvD2)
qed

end

(* With respect to the interface ℐ_full, map_gpv f g does not change whether
   a gpv satisfies the predicate.  Case distinction on b; each case proves
   both directions separately, by coinduction for b = True and by induction
   for b = False. *)
lemma gen_lossless_gpv_map_full [simp]:
  "gen_lossless_gpv b ℐ_full (map_gpv f g gpv) = gen_lossless_gpv b ℐ_full gpv"
  (is "?lhs = ?rhs")
proof(cases "b = True")
  case True
  show "?lhs = ?rhs"
  proof
    show ?rhs if ?lhs using that unfolding True
      by(coinduction arbitrary: gpv)(auto 4 3 dest: colossless_gpvD simp add: gpv.map_sel intro!: rev_image_eqI)
    show ?lhs if ?rhs using that unfolding True
      by(coinduction arbitrary: gpv)(auto 4 4 dest: colossless_gpvD simp add: gpv.map_sel intro!: rev_image_eqI)
  qed
next
  case False
  hence False: "b = False" by simp
  show "?lhs = ?rhs"
  proof
    show ?rhs if ?lhs using that unfolding False
      apply(induction gpv'"map_gpv f g gpv" arbitrary: gpv)
      subgoal for p gpv by(cases gpv)(rule lossless_gpvI; fastforce intro: rev_image_eqI)
      done
    show ?lhs if ?rhs using that unfolding False
      by induction(auto 4 4 intro: lossless_gpvI)
  qed
qed

(* For an arbitrary interface, mapping only the results (map_gpv f id) does
   not change the predicate.  Derived from parametricity by instantiating the
   relations with graphs (BNF_Def.Grp) of id and f. *)
lemma gen_lossless_gpv_map_id [simp]:
  "gen_lossless_gpv b  (map_gpv f id gpv) = gen_lossless_gpv b  gpv"
  using gen_lossless_gpv_parametric[of "BNF_Def.Grp UNIV id" "BNF_Def.Grp UNIV f"] unfolding gpv.rel_Grp
  by(simp add: rel_fun_def eq_alt[symmetric] rel_ℐ_eq)(auto simp add: Grp_def)

(* Results of TRY gpv ELSE gpv': the results of gpv, together with the
   results of gpv' in case gpv is not colossless.  The first inclusion is by
   induction on the derivation of a result of the try; for the second, results
   of gpv are handled by induction, and results of gpv' by contraposition:
   if x is not a result of the try, then gpv is shown colossless by
   coinduction. *)
lemma results_gpv_try_gpv [simp]:
  "results_gpv  (TRY gpv ELSE gpv') = 
   results_gpv  gpv  (if colossless_gpv  gpv then {} else results_gpv  gpv')"
  (is "?lhs = ?rhs")
proof(intro set_eqI iffI)
  show "x  ?rhs" if "x  ?lhs" for x using that
  proof(induction gpv''"try_gpv gpv gpv'" arbitrary: gpv)
    case Pure thus ?case
      by(auto split: if_split_asm intro: results_gpv.Pure dest: colossless_gpv_lossless_spmfD)
  next
    case (IO out c input)
    then show ?case
      apply(auto dest: colossless_gpv_lossless_spmfD split: if_split_asm)
      apply(force intro: results_gpv.IO dest: colossless_gpv_continuationD split: if_split_asm)+
      done
  qed
next
  fix x
  assume "x  ?rhs"
  then consider (left) "x  results_gpv  gpv" | (right) "¬ colossless_gpv  gpv" "x  results_gpv  gpv'"
    by(auto split: if_split_asm)
  thus "x  ?lhs"
  proof cases
    case left
    thus ?thesis 
      by(induction)(auto 4 4 intro: results_gpv.intros rev_image_eqI split del: if_split)
  next
    case right
    from right(1) show ?thesis
    proof(rule contrapos_np)
      assume "x  ?lhs"
      with right(2) show "colossless_gpv  gpv"
      proof(coinduction arbitrary: gpv)
        case (colossless_gpv gpv)
        then have ?lossless_spmf
          apply(rewrite in asm try_gpv.code)
          apply(rule ccontr)
          apply(erule results_gpv.cases)
          apply(fastforce simp add: image_Un image_image generat.map_comp o_def)+
          done
        moreover have "?continuation" using colossless_gpv
          by(auto 4 4 split del: if_split simp add: image_Un image_image generat.map_comp o_def intro: rev_image_eqI results_gpv.IO)
        ultimately show ?case ..
      qed
    qed
  qed
qed

(* Version of results_gpv_try_gpv for results'_gpv, obtained by rewriting
   results'_gpv as results_gpv with respect to ℐ_full. *)
lemma results'_gpv_try_gpv [simp]:
  "results'_gpv (TRY gpv ELSE gpv') = 
   results'_gpv gpv  (if colossless_gpv ℐ_full gpv then {} else results'_gpv gpv')"
by(simp add: results_gpv_ℐ_full[symmetric])

(* Outputs of TRY gpv ELSE gpv': the outputs of gpv, together with the
   outputs of gpv' in case gpv is not colossless with respect to ℐ_full.
   The proof has the same structure as the one of results_gpv_try_gpv:
   induction for the first inclusion, and for the second one induction for
   outputs of gpv and contraposition plus coinduction for outputs of gpv'. *)
lemma outs'_gpv_try_gpv [simp]:
  "outs'_gpv (TRY gpv ELSE gpv') =
   outs'_gpv gpv  (if colossless_gpv ℐ_full gpv then {} else outs'_gpv gpv')"
  (is "?lhs = ?rhs")
proof(intro set_eqI iffI)
  show "x  ?rhs" if "x  ?lhs" for x using that
  proof(induction gpv''"try_gpv gpv gpv'" arbitrary: gpv)
    case Out thus ?case
      by(auto 4 3 simp add: generat.map_comp o_def elim!: generat.set_cases(2) intro: outs'_gpv_Out split: if_split_asm dest: colossless_gpv_lossless_spmfD)
  next
    case (Cont generat c input)
    then show ?case
      apply(auto dest: colossless_gpv_lossless_spmfD split: if_split_asm elim!: generat.set_cases(3))
      apply(auto 4 3 dest: colossless_gpv_continuationD split: if_split_asm intro: outs'_gpv_Cont elim!: meta_allE meta_impE[OF _ refl])+
      done
  qed
next
  fix x
  assume "x  ?rhs"
  then consider (left) "x  outs'_gpv gpv" | (right) "¬ colossless_gpv ℐ_full gpv" "x  outs'_gpv gpv'"
    by(auto split: if_split_asm)
  thus "x  ?lhs"
  proof cases
    case left
    thus ?thesis 
      by(induction)(auto elim!: generat.set_cases(2,3) intro: outs'_gpvI intro!: rev_image_eqI split del: if_split simp add: image_Un image_image generat.map_comp o_def)
  next
    case right
    from right(1) show ?thesis
    proof(rule contrapos_np)
      assume "x  ?lhs"
      with right(2) show "colossless_gpv ℐ_full gpv"
      proof(coinduction arbitrary: gpv)
        case (colossless_gpv gpv)
        then have ?lossless_spmf
          apply(rewrite in asm try_gpv.code)
          apply(erule contrapos_np)
          apply(erule gpv.set_cases)
          apply(auto 4 3 simp add: image_Un image_image generat.map_comp o_def generat.set_map in_set_spmf[symmetric] bind_UNION generat.map_id[unfolded id_def] elim!: generat.set_cases)
          done
        moreover have "?continuation" using colossless_gpv
          by(auto simp add: image_Un image_image generat.map_comp o_def split del: if_split intro!: rev_image_eqI intro: outs'_gpv_Cont)
        ultimately show ?case ..
      qed
    qed
  qed
qed

(* pred_gpv on a try: the predicate must hold for gpv, and for gpv' whenever
   gpv is not colossless with respect to ℐ_full.  Follows from pred_gpv_def
   by auto, with the simp set as declared at this point. *)
lemma pred_gpv_try [simp]:
  "pred_gpv P Q (try_gpv gpv gpv') = (pred_gpv P Q gpv  (¬ colossless_gpv ℐ_full gpv  pred_gpv P Q gpv'))"
by(auto simp add: pred_gpv_def)

(* Induction principle for gpvs that are both lossless and well-typed.
   Compared with lossless_gpv_induct, the step case additionally provides that
   all outputs of p lie in outs_ℐ and that the continuations are well-typed. *)
lemma lossless_WT_gpv_induct [consumes 2, case_names lossless_gpv]:
  assumes lossless: "lossless_gpv  gpv"
  and WT: " ⊢g gpv "
  and step: "p. 
       lossless_spmf p;
       out c. IO out c  set_spmf p  out  outs_ℐ ;
       out c input. IO out c  set_spmf p; out  outs_ℐ   input  responses_ℐ  out  lossless_gpv  (c input);
       out c input. IO out c  set_spmf p; out  outs_ℐ   input  responses_ℐ  out   ⊢g c input ;
       out c input. IO out c  set_spmf p; out  outs_ℐ   input  responses_ℐ  out  P (c input)
       P (GPV p)"
  shows "P gpv"
using lossless WT
apply(induction)
apply(erule step)
apply(auto elim: WT_gpvD simp add: WT_gpv_simps)
done

(* Strong induction principle for lossless_gpv: the step case may assume
   losslessness and P for every gpv in sub_gpvs of GPV p, not only for the
   immediate continuations.  The proof strengthens the goal to all gpv' in
   insert gpv (sub_gpvs gpv) and then uses plain induction. *)
lemma lossless_gpv_induct_strong [consumes 1, case_names lossless_gpv]:
  assumes gpv: "lossless_gpv  gpv"
  and step:
  "p.  lossless_spmf p;
          gpv. gpv  sub_gpvs  (GPV p)  lossless_gpv  gpv;
          gpv. gpv  sub_gpvs  (GPV p)  P gpv 
        P (GPV p)"
  shows "P gpv"
proof -
  define gpv' where "gpv' = gpv"
  then have "gpv'  insert gpv (sub_gpvs  gpv)" by simp
  with gpv have "lossless_gpv  gpv'  P gpv'"
  proof(induction arbitrary: gpv')
    case (lossless_gpv p)
    from gpv'  insert (GPV p) _ show ?case
    proof(rule insertE)
      assume "gpv' = GPV p"
      moreover have "lossless_gpv  (GPV p)"
        by(auto 4 3 intro: lossless_gpvI lossless_gpv.hyps)
      moreover have "P (GPV p)" using lossless_gpv.hyps(1)
        by(rule step)(fastforce elim: sub_gpvs.cases lossless_gpv.IH[THEN conjunct1] lossless_gpv.IH[THEN conjunct2])+
      ultimately show ?case by simp
    qed(fastforce elim: sub_gpvs.cases lossless_gpv.IH[THEN conjunct1] lossless_gpv.IH[THEN conjunct2])
  qed
  thus ?thesis by(simp add: gpv'_def)
qed

(* Introduction rule via sub_gpvs: a gpv is lossless if its first step is
   lossless and every gpv in its sub_gpvs is lossless.  Only the base case of
   sub_gpvs is needed in the proof. *)
lemma lossless_sub_gpvsI:
  assumes spmf: "lossless_spmf (the_gpv gpv)"
  and sub: "gpv'. gpv'  sub_gpvs  gpv  lossless_gpv  gpv'"
  shows "lossless_gpv  gpv"
using spmf by(rule lossless_gpvI)(rule sub[OF sub_gpvs.base])

(* Losslessness is inherited by every gpv in sub_gpvs; by induction on the
   derivation of sub_gpvs membership. *)
lemma lossless_sub_gpvsD:
  assumes "lossless_gpv  gpv" "gpv'  sub_gpvs  gpv"
  shows "lossless_gpv  gpv'"
using assms(2,1) by(induction)(auto dest: lossless_gpvD)

(* Strong induction principle for lossless and well-typed gpvs: combines
   lossless_gpv_induct_strong with well-typedness, which is propagated to all
   sub_gpvs via WT_sub_gpvsD. *)
lemma lossless_WT_gpv_induct_strong [consumes 2, case_names lossless_gpv]:
  assumes lossless: "lossless_gpv  gpv"
  and WT: " ⊢g gpv "
  and step: "p.  lossless_spmf p;
       out c. IO out c  set_spmf p  out  outs_ℐ ;
       gpv. gpv  sub_gpvs  (GPV p)  lossless_gpv  gpv;
       gpv. gpv  sub_gpvs  (GPV p)   ⊢g gpv ;
       gpv. gpv  sub_gpvs  (GPV p)  P gpv 
       P (GPV p)"
  shows "P gpv"
using lossless WT
apply(induction rule: lossless_gpv_induct_strong)
apply(erule step)
apply(auto elim: WT_gpvD dest: WT_sub_gpvsD)
done

(* A try around a gpv that satisfies gen_lossless_gpv with respect to ℐ_full
   is redundant: TRY gpv ELSE gpv' equals gpv.  Proved by coinduction on gpv
   equality; the first step never fails, so the try on the spmf level can be
   added without changing it. *)
lemma try_gpv_gen_lossless: ― ‹TODO: generalise to arbitrary typings?›
  "gen_lossless_gpv b ℐ_full gpv  (TRY gpv ELSE gpv') = gpv"
proof(coinduction arbitrary: gpv)
  case (Eq_gpv gpv)
  from Eq_gpv[THEN gen_lossless_gpv_lossless_spmfD]
  have eq: "the_gpv gpv = (TRY the_gpv gpv ELSE the_gpv gpv')" by(simp)
  show ?case
    by(subst eq)(auto simp add: spmf_rel_map generat.rel_map[abs_def] intro!: rel_spmf_try_spmf rel_spmf_reflI rel_generat_reflI elim!: generat.set_cases gen_lossless_gpv_continuationD[OF Eq_gpv] simp add: Eq_gpv[THEN gen_lossless_gpv_lossless_spmfD])
qed

― ‹We instantiate the parameter @{term b} such that it can be used as a conditional simp rule.›
(* Both instances of try_gpv_gen_lossless, declared as simp rules. *)
lemmas try_gpv_lossless [simp] = try_gpv_gen_lossless[where b=False]
  and try_gpv_colossless [simp] = try_gpv_gen_lossless[where b=True]

(* If gpv satisfies gen_lossless_gpv with respect to ℐ_full, a try around
   bind_gpv gpv f can be pushed into the continuation f.  Proved by strong
   coinduction on gpv equality, with some relator rules declared locally as
   simp and intro rules. *)
lemma try_gpv_bind_gen_lossless: ― ‹TODO: generalise to arbitrary typings?›
  "gen_lossless_gpv b ℐ_full gpv  TRY bind_gpv gpv f ELSE gpv' = bind_gpv gpv (λx. TRY f x ELSE gpv')"
proof(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
  case (Eq_gpv gpv)
  note [simp] = spmf_rel_map generat.rel_map map_spmf_bind_spmf
    and [intro!] = rel_spmf_reflI rel_generat_reflI rel_funI
  show ?case using gen_lossless_gpvD[OF Eq_gpv]
    by(auto 4 3 simp del: bind_gpv_sel' simp add: bind_gpv.sel try_spmf_bind_spmf_lossless split: generat.split intro!: rel_spmf_bind_reflI rel_spmf_try_spmf)
qed

― ‹We instantiate the parameter @{term b} such that it can be used as a conditional simp rule.›
(* Both instances of try_gpv_bind_gen_lossless.  Unlike try_gpv_lossless and
   try_gpv_colossless, these are not declared [simp] here. *)
lemmas try_gpv_bind_lossless = try_gpv_bind_gen_lossless[where b=False]
  and try_gpv_bind_colossless = try_gpv_bind_gen_lossless[where b=True]

(* Congruence rule for try_gpv: the handlers only need to agree under the
   assumption that the tried gpv is not colossless with respect to ℐ_full. *)
lemma try_gpv_cong:
  " gpv = gpv''; ¬ colossless_gpv ℐ_full gpv''  gpv' = gpv''' 
   try_gpv gpv gpv' = try_gpv gpv'' gpv'''"
by(cases "colossless_gpv ℐ_full gpv''") simp_all

(* lemma try_gpv_bind_colossless2:
  "colossless_gpv ℐ_full gpv' ⟹ try_gpv (bind_gpv gpv f) gpv' = try_gpv (bind_gpv gpv (λx. try_gpv (f x) gpv')) gpv'"
apply(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
apply(simp add: spmf_rel_map bind_gpv_sel del: bind_gpv_sel')
apply(rule rel_spmf_try_spmf)
 apply(simp add: spmf_rel_map)
 apply(rule rel_spmf_bind_reflI)
 apply(simp split: generat.split)
 apply(rule conjI; clarsimp)
 apply(simp add: spmf_rel_map generat.map_comp o_def generat.rel_map rel_fun_def)
 apply(subst map_try_spmf[symmetric])
 *)

(* mk_lossless_gpv B x gpv rewrites gpv corecursively: Pure results are kept;
   for an IO step, the continuation is kept (and transformed recursively) on
   inputs in B out and replaced by Done x on all other inputs.  B and the
   default result x are fixed in the surrounding context. *)
context fixes B :: "'b  'c set" and x :: 'a begin

primcorec mk_lossless_gpv :: "('a, 'b, 'c) gpv  ('a, 'b, 'c) gpv" where
  "the_gpv (mk_lossless_gpv gpv) =
   map_spmf (λgenerat. case generat of Pure x  Pure x 
      | IO out c  IO out (λinput. if input  B out then mk_lossless_gpv (c input) else Done x))
    (the_gpv gpv)"

end

(* mk_lossless_gpv, instantiated with the responses of the original
   interface, yields a gpv that is well-typed for any interface ℐ' with the
   same outs_ℐ as the original one.  Proved by coinduction. *)
lemma WT_gpv_mk_lossless_gpv:
  assumes " ⊢g gpv "
    and outs: "outs_ℐ ℐ' = outs_ℐ "
  shows "ℐ' ⊢g mk_lossless_gpv (responses_ℐ ) x gpv "
  using assms(1)
  by(coinduction arbitrary: gpv)(auto 4 3 split: generat.split_asm simp add: outs dest: WT_gpvD)

subsection ‹Sequencing with failure handling included›

(* catch_gpv wraps the results of gpv in Some and falls back to returning
   None through the ELSE branch of the try. *)
definition catch_gpv :: "('a, 'out, 'in) gpv  ('a option, 'out, 'in) gpv"
where "catch_gpv gpv = TRY map_gpv Some id gpv ELSE Done None"

(* Catching an immediate result wraps it in Some. *)
lemma catch_gpv_Done [simp]: "catch_gpv (Done x) = Done (Some x)"
by(simp add: catch_gpv_def)

(* Catching the failing gpv returns None. *)
lemma catch_gpv_Fail [simp]: "catch_gpv Fail = Done None"
by(simp add: catch_gpv_def)

(* catch_gpv commutes with Pause: it is pushed into the continuation. *)
lemma catch_gpv_Pause [simp]: "catch_gpv (Pause out rpv) = Pause out (λinput. catch_gpv (rpv input))"
by(simp add: catch_gpv_def)

(* Catching a lifted spmf p equals lifting spmf_of_pmf p, i.e. p viewed as a
   pmf over optional values.  Proved via gpv.expand by unfolding the
   definitions down to bind_pmf. *)
lemma catch_gpv_lift_spmf [simp]: "catch_gpv (lift_spmf p) = lift_spmf (spmf_of_pmf p)"
by(rule gpv.expand)(auto simp add: catch_gpv_def spmf_of_pmf_def map_lift_spmf try_spmf_def o_def map_pmf_def bind_assoc_pmf bind_return_pmf intro!: bind_pmf_cong[OF refl] split: option.split)

(* Catching an assertion returns the corresponding assert_option value;
   by cases on b. *)
lemma catch_gpv_assert [simp]: "catch_gpv (assert_gpv b) = Done (assert_option b)"
by(cases b) simp_all

(* Selector equation for catch_gpv: the first step maps results to Some and
   applies catch_gpv to the continuations, with a try on the spmf level that
   returns Pure None as fallback. *)
lemma catch_gpv_sel [simp]:
  "the_gpv (catch_gpv gpv) = 
   TRY map_spmf (map_generat Some id (λrpv input. catch_gpv (rpv input))) (the_gpv gpv)
   ELSE return_spmf (Pure None)"
by(simp add: catch_gpv_def gpv.map_sel spmf.map_comp o_def generat.map_comp map_try_spmf id_def)

(* catch_gpv distributes over bind_gpv: the first gpv is caught; a None
   outcome is propagated as Done None, and on Some x' the continuation f x'
   is caught in turn.  Proved by strong coinduction on gpv equality: after
   unfolding the selectors, the spmf-level bind and try are rewritten to
   pmf-level binds so that rel_pmf_bind_reflI applies. *)
lemma catch_gpv_bind_gpv: "catch_gpv (bind_gpv gpv f) = bind_gpv (catch_gpv gpv) (λx. case x of None  Done None | Some x'  catch_gpv (f x'))"
  apply(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
  apply(clarsimp simp add: map_bind_pmf bind_gpv.sel spmf.map_comp o_def[abs_def] map_bind_spmf generat.map_comp simp del: bind_gpv_sel')
  apply(subst bind_spmf_def)
  apply(subst try_spmf_bind_pmf)
  apply(subst (2) try_spmf_def)
  apply(subst bind_spmf_pmf_assoc)
  apply(simp add: bind_map_pmf)
  apply(rule rel_pmf_bind_reflI)
  apply(auto split!: option.split generat.split simp add: spmf_rel_map spmf.map_comp o_def generat.map_comp id_def[symmetric] generat.map_id rel_spmf_reflI generat.rel_refl refl rel_fun_def)
  done

(* Parametricity of catch_gpv, for rel_gpv (registered as a transfer rule)
   and for rel_gpv''.  Both follow from the definition by transfer_prover;
   the second needs the primed parametricity rules of the constituents. *)
context includes lifting_syntax begin
lemma catch_gpv_parametric [transfer_rule]:
  "(rel_gpv A C ===> rel_gpv (rel_option A) C) catch_gpv catch_gpv"
unfolding catch_gpv_def by transfer_prover

lemma catch_gpv_parametric':
  notes [transfer_rule] = try_gpv_parametric' map_gpv_parametric' Done_parametric'
  shows "(rel_gpv'' A C R ===> rel_gpv'' (rel_option A) C R) catch_gpv catch_gpv"
unfolding catch_gpv_def by transfer_prover
end

(* catch_gpv commutes with map_gpv'; the result mapping is lifted through
   map_option. *)
lemma catch_gpv_map': "catch_gpv (map_gpv' f g h gpv) = map_gpv' (map_option f) g h (catch_gpv gpv)"
by(simp add: catch_gpv_def map'_try_gpv map_gpv_conv_map_gpv' map_gpv'_comp o_def)

(* catch_gpv commutes with map_gpv; a special case of catch_gpv_map'. *)
lemma catch_gpv_map: "catch_gpv (map_gpv f g gpv) = map_gpv (map_option f) g (catch_gpv gpv)"
  by(simp add: map_gpv_conv_map_gpv' catch_gpv_map')

(* A caught gpv is always colossless with respect to ℐ_full. *)
lemma colossless_gpv_catch_gpv [simp]: "colossless_gpv ℐ_full (catch_gpv gpv)"
by(coinduction arbitrary: gpv) auto

(* For a gpv that is colossless with respect to ℐ_full, catching amounts to
   wrapping the results in Some.
   NOTE(review): the lemma name is spelled "colosless" (single s); it is kept
   as is because the name is referenced elsewhere. *)
lemma colosless_gpv_catch_gpv_conv_map:
  "colossless_gpv ℐ_full gpv  catch_gpv gpv = map_gpv Some id gpv"
  apply(coinduction arbitrary: gpv)
  apply(frule colossless_gpv_lossless_spmfD)
  apply(auto simp add: spmf_rel_map gpv.map_sel generat.rel_map intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI elim!: colossless_gpv_continuationD generat.set_cases)
  done

(* Catching twice only adds another Some wrapper, since a caught gpv is
   already colossless. *)
lemma catch_gpv_catch_gpv [simp]: "catch_gpv (catch_gpv gpv) = map_gpv Some id (catch_gpv gpv)"
  by(simp add: colosless_gpv_catch_gpv_conv_map)

(* Case distinction on a mapped resumption can be pushed through
   map_resumption into the two branches; by cases on r. *)
lemma case_map_resumption: (* Move to Resumption *)
  "case_resumption done pause (map_resumption f g r) = 
   case_resumption (done  map_option f) (λout c. pause (g out) (map_resumption f g  c)) r" 
by(cases r) simp_all

(* Attempted statement about catch_gpv on lifted resumptions.  The proof is
   abandoned with oops, so no theorem of this name is stored and the [simp]
   attribute has no effect. *)
lemma catch_gpv_lift_resumption [simp]: "catch_gpv (lift_resumption r) = lift_resumption (map_resumption Some id r)"
  apply(coinduction arbitrary: r)
  apply(auto simp add: lift_resumption.sel case_map_resumption split: resumption.split option.split)
  oops (* TODO: We'd need a catch_resumption for this *)

(* Results of a caught gpv: the original results wrapped in Some, plus None
   if the gpv is not colossless. *)
lemma results_gpv_catch_gpv:
  "results_gpv  (catch_gpv gpv) = Some ` results_gpv  gpv  (if colossless_gpv  gpv then {} else {None})"
  by(simp add: catch_gpv_def)

(* Membership of Some x in the results of a caught gpv reduces to membership
   of x in the results of the gpv. *)
lemma Some_in_results_gpv_catch_gpv [simp]:
  "Some x  results_gpv  (catch_gpv gpv)  x  results_gpv  gpv"
  by(auto simp add: results_gpv_catch_gpv)

(* Membership of None in the results of a caught gpv reduces to the gpv not
   being colossless. *)
lemma None_in_results_gpv_catch_gpv [simp]:
  "None  results_gpv  (catch_gpv gpv)  ¬ colossless_gpv  gpv"
  by(auto simp add: results_gpv_catch_gpv)

(* Version of results_gpv_catch_gpv for results'_gpv, with colossless_gpv
   taken with respect to ℐ_full. *)
lemma results'_gpv_catch_gpv:
  "results'_gpv (catch_gpv gpv) = Some ` results'_gpv gpv  (if colossless_gpv ℐ_full gpv then {} else {None})"
  by(simp add: results_gpv_ℐ_full[symmetric] results_gpv_catch_gpv)

(* Version of Some_in_results_gpv_catch_gpv for results'_gpv. *)
lemma Some_in_results'_gpv_catch_gpv [simp]:
  "Some x  results'_gpv (catch_gpv gpv)  x  results'_gpv gpv"
  by(simp add: results_gpv_ℐ_full[symmetric])

(* Version of None_in_results_gpv_catch_gpv for results'_gpv. *)
lemma None_in_results'_gpv_catch_gpv [simp]:
  "None  results'_gpv (catch_gpv gpv)  ¬ colossless_gpv ℐ_full gpv"
  by(simp add: results_gpv_ℐ_full[symmetric])

(* Elimination rule for results of a caught gpv: either the result is Some x'
   for a result x' of gpv, or it is None and gpv is not colossless with
   respect to ℐ_full.
   NOTE(review): the second case is named "colossless" although it covers the
   situation where gpv is not colossless. *)
lemma results'_gpv_catch_gpvE:
  assumes "x  results'_gpv (catch_gpv gpv)"
  obtains (Some) x'
  where "x = Some x'" "x'  results'_gpv gpv"
  | (colossless) "x = None" "¬ colossless_gpv ℐ_full gpv"
  using assms by(auto simp add: results'_gpv_catch_gpv split: if_split_asm)

(* Catching does not change the set of outputs. *)
lemma outs'_gpv_catch_gpv [simp]: "outs'_gpv (catch_gpv gpv) = outs'_gpv gpv"
  by(simp add: catch_gpv_def)

(* pred_gpv on a caught gpv with pred_option P on results coincides with
   pred_gpv P on the original gpv. *)
lemma pred_gpv_catch_gpv [simp]: "pred_gpv (pred_option P) Q (catch_gpv gpv) = pred_gpv P Q gpv"
  by(simp add: pred_gpv_def results'_gpv_catch_gpv)

(* Bind with failure handling: the first gpv is caught, so the continuation
   receives an option value (None when the first gpv fails). *)
abbreviation bind_gpv' :: "('a, 'call, 'ret) gpv  ('a option  ('b, 'call, 'ret) gpv)  ('b, 'call, 'ret) gpv"
where "bind_gpv' gpv  bind_gpv (catch_gpv gpv)"

(* lemma bind_gpv'_sel [simp]:
  "the_gpv (bind_gpv' gpv f) =
   bind_pmf (the_gpv gpv) (λx. case x of
     None ⇒ the_gpv (f None)
   | Some (Pure x) ⇒ the_gpv (f (Some x))
   | Some (IO out rpv) ⇒ return_spmf (IO out (λinput. bind_gpv' (rpv input) f)))"
by(auto simp add: bind_gpv'_def bind_map_spmf try_spmf_def bind_spmf_pmf_assoc bind_map_pmf gpv.map_sel intro!: bind_pmf_cong[OF refl] split: option.split generat.split)
 *)
  
(* Associativity of bind_gpv'. *)
lemma bind_gpv'_assoc [simp]: "bind_gpv' (bind_gpv' gpv f) g = bind_gpv' gpv (λx. bind_gpv' (f x) g)"
by(simp add: catch_gpv_bind_gpv bind_map_gpv o_def bind_gpv_assoc)

(* Reassociation of a bind_gpv' whose first argument is a plain bind_gpv:
   a failure of gpv is passed to g directly as None, otherwise f y is bound
   with bind_gpv' to g. *)
lemma bind_gpv'_bind_gpv: "bind_gpv' (bind_gpv gpv f) g = bind_gpv' gpv (case_option (g None) (λy. bind_gpv' (f y) g))"
  by(clarsimp simp add: catch_gpv_bind_gpv bind_gpv_assoc intro!: bind_gpv_cong[OF refl] split: option.split)

(* Congruence rule for bind_gpv': the continuations need to agree only on
   Some-wrapped results of the gpv and, if the gpv is not colossless with
   respect to ℐ_full, on None. *)
lemma bind_gpv'_cong:
  " gpv = gpv'; x. x  Some ` results'_gpv gpv'  (¬ colossless_gpv ℐ_full gpv  x = None)  f x = f' x 
   bind_gpv' gpv f = bind_gpv' gpv' f'"
by(auto elim: results'_gpv_catch_gpvE split: if_split_asm intro!: bind_gpv_cong[OF refl])

(* Variant of bind_gpv'_cong with the Some and None cases stated as separate
   premises. *)
lemma bind_gpv'_cong2:
  " gpv = gpv'; x. x  results'_gpv gpv'  f (Some x) = f' (Some x); ¬ colossless_gpv ℐ_full gpv  f None = f' None 
   bind_gpv' gpv f = bind_gpv' gpv' f'"
by(rule bind_gpv'_cong) auto

subsection ‹Inlining›

(* Coinduction up-to bind for gpv equality.  In the step, related
   continuations may be in R, be equal, or be binds of two gpvs that are
   related by rel_gpv such that the bound functions map related results into
   R.  The proof starts from trivial binds (Done x and Done y with constant
   continuations) and then uses strong coinduction over all such binds. *)
lemma gpv_coinduct_bind [consumes 1, case_names Eq_gpv]:
  fixes gpv gpv' :: "('a, 'call, 'ret) gpv"
  assumes *: "R gpv gpv'"
  and step: "gpv gpv'. R gpv gpv' 
     rel_spmf (rel_generat (=) (=) (rel_fun (=) (λgpv gpv'. R gpv gpv'  gpv = gpv'  
      (gpv2 :: ('b, 'call, 'ret) gpv. gpv2' :: ('c, 'call, 'ret) gpv. f f'. gpv = bind_gpv gpv2 f  gpv' = bind_gpv gpv2' f'  
        rel_gpv (λx y. R (f x) (f' y)) (=) gpv2 gpv2'))))
      (the_gpv gpv) (the_gpv gpv')"
  shows "gpv = gpv'"
proof -
  (* x and y are arbitrary values, needed only to build the dummy gpvs
     Done x and Done y. *)
  fix x y
  define gpv1 :: "('b, 'call, 'ret) gpv"
    and f :: "'b  ('a, 'call, 'ret) gpv"
    and gpv1' :: "('c, 'call, 'ret) gpv"
    and f' :: "'c  ('a, 'call, 'ret) gpv"
    where "gpv1 = Done x"
      and "f = (λ_. gpv)"
      and "gpv1' = Done y"
      and "f' = (λ_. gpv')"
  from * have "rel_gpv (λx y. R (f x) (f' y)) (=) gpv1 gpv1'"
    by(simp add: gpv1_def gpv1'_def f_def f'_def)
  then have "gpv1  f = gpv1'  f'"
  proof(coinduction arbitrary: gpv1 gpv1' f f' rule: gpv.coinduct_strong)
    case (Eq_gpv gpv1 gpv1' f f')
    from Eq_gpv[simplified gpv.rel_sel] show ?case unfolding bind_gpv.sel spmf_rel_map
      apply(rule rel_spmf_bindI)
      subgoal for generat generat'
        apply(cases generat generat' rule: generat.exhaust[case_product generat.exhaust]; clarsimp simp add: o_def spmf_rel_map generat.rel_map)
        subgoal premises Pure for x y
          using step[OF R (f x) (f' y)] apply -
          apply(assumption | rule rel_spmf_mono rel_generat_mono rel_fun_mono refl)+
          apply(fastforce intro: exI[where x="Done _"])+
          done
        subgoal by(fastforce simp add: rel_fun_def)
        done
      done
  qed
  thus ?thesis by(simp add: gpv1_def gpv1'_def f_def f'_def)
qed

text ‹
  Inlining one gpv into another. This may throw out arbitrarily many
  interactions between the two gpvs if the inlined one does not call its callee.
  So we define it as the coiteration of a least-fixpoint search operator.
›

context
  fixes callee :: "'s  'call  ('ret × 's, 'call', 'ret') gpv"
  notes [[function_internals]]
begin

(* inline1 gpv s runs gpv in state s until the callee itself issues a call.
   If gpv returns x, the outcome is Inl (x, s).  If gpv issues a call out,
   callee s out is run: when it returns (x, y), the search continues with the
   continuation rpv x in the new state y; when it issues a call out, the
   outcome is Inr with that call, the callee's continuation rpv' and the
   caller's continuation rpv.  Defined with partial_function in the spmf
   monad, since the recursion need not terminate. *)
partial_function (spmf) inline1
  :: "('a, 'call, 'ret) gpv  's
   ('a × 's + 'call' × ('ret × 's, 'call', 'ret') rpv × ('a, 'call, 'ret) rpv) spmf"
where
  "inline1 gpv s =
   the_gpv gpv 
   case_generat (λx. return_spmf (Inl (x, s)))
     (λout rpv. the_gpv (callee s out) 
         case_generat (λ(x, y). inline1 (rpv x) y)
          (λout rpv'. return_spmf (Inr (out, rpv', rpv))))"

(* The defining equation of inline1 under a separate name (a copy of
   inline1.simps). *)
lemma inline1_unfold:
  "inline1 gpv s =
   the_gpv gpv 
   case_generat (λx. return_spmf (Inl (x, s)))
     (λout rpv. the_gpv (callee s out) 
         case_generat (λ(x, y). inline1 (rpv x) y)
          (λout rpv'. return_spmf (Inr (out, rpv', rpv))))"
by(fact inline1.simps)

(* Fixpoint induction for inline1 in curried form: P must be admissible, hold
   for the function that always returns return_pmf None, and be preserved by
   one unfolding of the defining equation.  Obtained from inline1.fixp_induct
   by unfolding curry. *)
lemma inline1_fixp_induct [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λinline1'. P (λgpv s. inline1' (gpv, s)))"
  and "P (λ_ _. return_pmf None)"
  and "inline1'. P inline1'  P (λgpv s. the_gpv gpv  case_generat (λx. return_spmf (Inl (x, s))) (λout rpv. the_gpv (callee s out)  case_generat (λ(x, y). inline1' (rpv x) y) (λout rpv'. return_spmf (Inr (out, rpv', rpv)))))"
  shows "P inline1"
using assms by(rule inline1.fixp_induct[unfolded curry_conv[abs_def]])

(* Stronger fixpoint induction for inline1: in the step case the approximation
   inline1' comes with the extra hypothesis relating it to inline1 itself via
   ord_spmf (=), pointwise in gpv and s.
   Derived from spmf.fixp_strong_induct_uc, instantiated with curry / case_prod
   and then simplified into curried form. *)
lemma inline1_fixp_induct_strong [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λinline1'. P (λgpv s. inline1' (gpv, s)))"
  and "P (λ_ _. return_pmf None)"
  and "inline1'.  gpv s. ord_spmf (=) (inline1' gpv s) (inline1 gpv s); P inline1' 
     P (λgpv s. the_gpv gpv  case_generat (λx. return_spmf (Inl (x, s))) (λout rpv. the_gpv (callee s out)  case_generat (λ(x, y). inline1' (rpv x) y) (λout rpv'. return_spmf (Inr (out, rpv', rpv)))))"
  shows "P inline1"
using assms by(rule spmf.fixp_strong_induct_uc[where P="λf. P (curry f)" and U=case_prod and C=curry, OF inline1.mono inline1_def, simplified curry_case_prod, simplified curry_conv[abs_def] fun_ord_def split_paired_All prod.case case_prod_eta, OF refl]) blast+

(* Second strengthening of fixpoint induction for inline1. The step case gets
   two ord_spmf (=) hypotheses about the approximation inline1':
   one relating it to inline1, and one relating it to its own one-step
   unfolding through the defining functional.
   Derived from spmf.fixp_induct_strong2_uc in the same way as
   inline1_fixp_induct_strong. *)
lemma inline1_fixp_induct_strong2 [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λinline1'. P (λgpv s. inline1' (gpv, s)))"
  and "P (λ_ _. return_pmf None)"
  and "inline1'. 
     gpv s. ord_spmf (=) (inline1' gpv s) (inline1 gpv s); 
      gpv s. ord_spmf (=) (inline1' gpv s) (the_gpv gpv  case_generat (λx. return_spmf (Inl (x, s))) (λout rpv. the_gpv (callee s out)  case_generat (λ(x, y). inline1' (rpv x) y) (λout rpv'. return_spmf (Inr (out, rpv', rpv)))));
      P inline1' 
     P (λgpv s. the_gpv gpv  case_generat (λx. return_spmf (Inl (x, s))) (λout rpv. the_gpv (callee s out)  case_generat (λ(x, y). inline1' (rpv x) y) (λout rpv'. return_spmf (Inr (out, rpv', rpv)))))"
  shows "P inline1"
using assms
by(rule spmf.fixp_induct_strong2_uc[where P="λf. P (curry f)" and U=case_prod and C=curry, OF inline1.mono inline1_def, simplified curry_case_prod, simplified curry_conv[abs_def] fun_ord_def split_paired_All prod.case case_prod_eta, OF refl]) blast+

text ‹
  Iterate @{const inline1} over all interactions. We'd like to use @{const bind_gpv} before
  the recursive call, but primcorec does not support this. So we emulate @{const bind_gpv}
  by effectively defining two mutually recursive functions (sum type in the argument)
  where the second is exactly @{const bind_gpv} specialised to call ‹inline› in the bind.
›

(* Corecursive worker behind inline. The argument is a sum of two states:
   - Inl (c, s): run inline1 c s. A result Inl (x, s) becomes Pure (x, s);
     a result Inr (out, oracle, rpv) becomes IO out, and after the input
     arrives the computation resumes in state Inr (rpv, oracle input).
   - Inr (rpv, c): c is a pending callee computation and rpv the suspended
     caller. If c returns Pure (x, s'), continue with inline1 (rpv x) s' and
     treat its result as in the Inl state; if c issues IO out c, emit that
     call and stay in the Inr state with the callee continuation.
   The generated simp rules are removed from the simpset right after the
   definition; later proofs refer to inline_aux.sel explicitly. *)
primcorec inline_aux
  :: "('a, 'call, 'ret) gpv × 's + ('ret  ('a, 'call, 'ret) gpv) × ('ret × 's, 'call', 'ret') gpv
   ('a × 's, 'call', 'ret') gpv"
where
  "state. the_gpv (inline_aux state) =
  (case state of Inl (c, s)  map_spmf (λresult. 
     case result of Inl (x, s)  Pure (x, s)
     | Inr (out, oracle, rpv)  IO out (λinput. inline_aux (Inr (rpv, oracle input)))) (inline1 c s)
  | Inr (rpv, c)   
    map_spmf (λresult. 
       case result of Inl (Inl (x, s))  Pure (x, s)
       | Inl (Inr (out, oracle, rpv))  IO out (λinput. inline_aux (Inr (rpv, oracle input)))
       | Inr (out, c)  IO out (λinput. inline_aux (Inr (rpv, c input))))
  (bind_spmf (the_gpv c) (λgenerat. case generat of Pure (x, s')  (map_spmf Inl (inline1 (rpv x) s'))
     | IO out c  return_spmf (Inr (out, c)))
     ))"

declare inline_aux.simps[simp del]

(* inline c s: the gpv obtained by inlining callee into c with initial callee
   state s. It is inline_aux started in the Inl state; results are paired
   with the final callee state. *)
definition inline :: "('a, 'call, 'ret) gpv  's  ('a × 's, 'call', 'ret') gpv"
where "inline c s = inline_aux (Inl (c, s))"

(* The Inr state of inline_aux is a bind: run the pending callee computation
   oracl, then inline the resumed caller rpv x with the returned state.
   Proved by strong coinduction on gpv after unfolding inline_def. *)
lemma inline_aux_Inr:
  "inline_aux (Inr (rpv, oracl)) = bind_gpv oracl (λ(x, s). inline (rpv x) s)"
unfolding inline_def
apply(coinduction arbitrary: oracl rule: gpv.coinduct_strong)
apply(simp add: inline_aux.sel bind_gpv.sel spmf_rel_map del: bind_gpv_sel')
apply(rule rel_spmf_bindI[where R="(=)"])
apply(auto simp add: spmf_rel_map inline_aux.sel rel_spmf_reflI generat.rel_map generat.rel_refl rel_fun_def split: generat.split)
done

(* Selector equation for inline, stated without inline_aux: the first step of
   inline c s is inline1 c s, where Inl xs is mapped to Pure xs and
   Inr (out, oracle, rpv) is mapped to IO out whose continuation binds the
   rest of the callee run into inline (rpv x) s'. *)
lemma inline_sel:
  "the_gpv (inline c s) = 
   map_spmf (λresult. case result of Inl xs  Pure xs
                       | Inr (out, oracle, rpv)  IO out (λinput. bind_gpv (oracle input) (λ(x, s'). inline (rpv x) s'))) (inline1 c s)"
by(simp add: inline_def inline_aux.sel inline_aux_Inr cong del: sum.case_cong)

(* inline1 applied to Fail yields return_pmf None; simp rule. *)
lemma inline1_Fail [simp]: "inline1 Fail s = return_pmf None"
by(rewrite inline1.simps) simp

(* Inlining into Fail yields Fail; simp rule, proved via inline_sel. *)
lemma inline_Fail [simp]: "inline Fail s = Fail"
by(rule gpv.expand)(simp add: inline_sel)

(* inline1 on Done x returns Inl (x, s) with the callee state s unchanged;
   simp rule. *)
lemma inline1_Done [simp]: "inline1 (Done x) s = return_spmf (Inl (x, s))"
by(rewrite inline1.simps) simp

(* Inlining into Done x yields Done (x, s): the result paired with the
   unchanged callee state; simp rule. *)
lemma inline_Done [simp]: "inline (Done x) s = Done (x, s)"
by(rule gpv.expand)(simp add: inline_sel)

(* inline1 on lift_spmf p: every outcome x of p is tagged as Inl (x, s);
   no Inr outcome appears on the right-hand side. Simp rule. *)
lemma inline1_lift_spmf [simp]: "inline1 (lift_spmf p) s = map_spmf (λx. Inl (x, s)) p"
by(rewrite inline1.simps)(simp add: bind_map_spmf o_def map_spmf_conv_bind_spmf)

(* Inlining into lift_spmf p is lift_spmf of p with every result paired with
   the unchanged callee state s; simp rule. *)
lemma inline_lift_spmf [simp]: "inline (lift_spmf p) s = lift_spmf (map_spmf (λx. (x, s)) p)"
by(rule gpv.expand)(simp add: inline_sel spmf.map_comp o_def)

(* inline1 on Pause out c: run callee s out; on Pure (x, s') continue with
   inline1 (c x) s', on IO out' c' stop with Inr (out', c', c).
   Not declared [simp]. *)
lemma inline1_Pause:
  "inline1 (Pause out c) s = 
  the_gpv (callee s out)  (λreact. case react of Pure (x, s')  inline1 (c x) s' | IO out' c'  return_spmf (Inr (out', c', c)))"
by(rewrite inline1.simps) simp

(* Inlining into Pause out c: the call out is replaced by running
   callee s out, whose answer (x, s') is fed into inline (c x) s'.
   Simp rule; proved from inline_sel and inline1_Pause. *)
lemma inline_Pause [simp]:
  "inline (Pause out c) s = callee s out  (λ(x, s'). inline (c x) s')"
by(rule gpv.expand)(auto simp add: inline_sel inline1_Pause map_spmf_bind_spmf bind_gpv.sel o_def[abs_def] spmf.map_comp generat.map_comp id_def generat.map_id[unfolded id_def] simp del: bind_gpv_sel' intro!: bind_spmf_cong[OF refl] split: generat.split)

(* inline1 of a bind: run inline1 on gpv first. If that ends with
   Inl (x, s'), continue with inline1 (f x) s'; if it stops at a callee call
   Inr (out, rpv', rpv), return the same call with the caller continuation
   extended to bind_rpv rpv f.
   inline11, inline12 and inline13 are local names that all stand for
   inline1. Each direction unfolds only one of them before the fixpoint
   induction, presumably so that the induction replaces just that occurrence.
   The equality is shown by antisymmetry of ord_spmf (=), with one fixpoint
   induction per direction. *)
lemma inline1_bind_gpv:
  fixes gpv f s
  defines [simp]: "inline11  inline1" and [simp]: "inline12  inline1" and [simp]: "inline13  inline1"
  shows "inline11 (bind_gpv gpv f) s = bind_spmf (inline12 gpv s) 
    (λres. case res of Inl (x, s')  inline13 (f x) s' | Inr (out, rpv', rpv)  return_spmf (Inr (out, rpv', bind_rpv rpv f)))"
  (is "?lhs = ?rhs")
proof(rule spmf.leq_antisym)
  note [intro!] = ord_spmf_bind_reflI and [split] = generat.split
  (* first direction: induct over the inline1 on the left-hand side *)
  show "ord_spmf (=) ?lhs ?rhs" unfolding inline11_def
  proof(induction arbitrary: gpv s f rule: inline1_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case (step inline1')
    show ?case unfolding inline12_def
      apply(rewrite inline1.simps; clarsimp simp add: bind_rpv_def)
      apply(rule conjI; clarsimp)
      subgoal premises Pure for x
        apply(rewrite inline1.simps; clarsimp)
        subgoal for out c ret s' using step.IH[of "Done x" "λ_. c ret" s'] by simp
        done
      subgoal for out c ret s' using step.IH[of "c ret" f s'] by(simp cong del: sum.case_cong_weak)
      done
  qed
  (* second direction: induct over the first inline1 on the right-hand side *)
  show "ord_spmf (=) ?rhs ?lhs" unfolding inline12_def
  proof(induction arbitrary: gpv s rule: inline1_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case (step inline1')
    show ?case unfolding inline11_def
      apply(rewrite inline1.simps; clarsimp simp add: bind_rpv_def)
      apply(rule conjI; clarsimp)
      subgoal by(rewrite inline1.simps; simp)
      subgoal for out c ret s' using step.IH[of "c ret" s'] by(simp cong del: sum.case_cong_weak)
      done
  qed
qed

(* inline distributes over bind_gpv: inlining into a bind equals inlining
   into the first part and then, with the resulting state s', inlining into
   the continuation. Simp rule; proved by coinduction with gpv_coinduct_bind
   using inline_sel and inline1_bind_gpv. *)
lemma inline_bind_gpv [simp]:
  "inline (bind_gpv gpv f) s = bind_gpv (inline gpv s) (λ(x, s'). inline (f x) s')"
apply(coinduction arbitrary: gpv s rule: gpv_coinduct_bind)
apply(clarsimp simp add: map_spmf_bind_spmf o_def[abs_def] bind_gpv.sel inline_sel bind_map_spmf inline1_bind_gpv simp del: bind_gpv_sel' intro!: rel_spmf_bind_reflI split: generat.split)
apply(rule conjI)
 subgoal by(auto split: sum.split_asm simp add: spmf_rel_map spmf.map_comp o_def generat.map_comp generat.map_id[unfolded id_def] spmf.map_id[unfolded id_def] inline_sel intro!: rel_spmf_reflI generat.rel_refl fun.rel_refl)
by(auto split: sum.split_asm simp add: bind_gpv_assoc split_def intro!: gpv.rel_refl exI disjI2 rel_funI)

end

(* When the callee is of the form lift_spmf (p s x), the support of inline1
   is related to range Inl, i.e. to the outcomes that are not an outer call.
   Proved by fixpoint induction on inline1. *)
lemma set_inline1_lift_spmf1: "set_spmf (inline1 (λs x. lift_spmf (p s x)) gpv s)  range Inl"
apply(induction arbitrary: gpv s rule: inline1_fixp_induct)
subgoal by(rule cont_intro ccpo_class.admissible_leI)+
apply(auto simp add: o_def bind_UNION split: generat.split_asm)+
done

(* Elementwise form of set_inline1_lift_spmf1: an outcome y of inline1 with a
   lift_spmf callee has the shape Inl (r, s'). Used as a dest rule below. *)
lemma in_set_inline1_lift_spmf1: "y  set_spmf (inline1 (λs x. lift_spmf (p s x)) gpv s)  r s'. y = Inl (r, s')"
by(drule set_inline1_lift_spmf1[THEN subsetD]) auto

(* For a callee of the form lift_spmf (p s c), inline collapses to a single
   lift_spmf: take inline1 and project its outcomes with projl. The proof
   relies on in_set_inline1_lift_spmf1 to know all outcomes are Inl. *)
lemma inline_lift_spmf1:
  fixes p defines "callee  λs c. lift_spmf (p s c)"
  shows "inline callee gpv s = lift_spmf (map_spmf projl (inline1 callee gpv s))"
by(rule gpv.expand)(auto simp add: inline_sel spmf.map_comp callee_def intro!: map_spmf_cong[OF refl] dest: in_set_inline1_lift_spmf1)

context includes lifting_syntax begin
(* Parametricity of inline1 with respect to rel_gpv'', which takes separate
   relations for results, calls and responses. Proved by unfolding
   inline1_def and applying fixp_spmf_parametric with the monotonicity fact
   inline1.mono on both sides; transfer_prover discharges the rest. *)
lemma inline1_parametric':
  "((S ===> C ===> rel_gpv'' (rel_prod R S) C' R') ===> rel_gpv'' A C R ===> S
   ===> rel_spmf (rel_sum (rel_prod A S) (rel_prod C' (rel_prod (R' ===> rel_gpv'' (rel_prod R S) C' R') (R ===> rel_gpv'' A C R)))))
  inline1 inline1"
  (is "(_ ===> ?R) _ _")
proof(rule rel_funI)
  note [transfer_rule] = the_gpv_parametric'
  show "?R (inline1 callee) (inline1 callee')" 
    if [transfer_rule]: "(S ===> C ===> rel_gpv'' (rel_prod R S) C' R') callee callee'"
    for callee callee'
    unfolding inline1_def
    by(unfold rel_fun_curry case_prod_curry)(rule fixp_spmf_parametric[OF inline1.mono inline1.mono]; transfer_prover)
qed

(* Parametricity of inline1 stated with rel_gpv and rel_rpv; registered as a
   transfer rule. Follows from inline1_parametric' after rewriting rel_gpv
   with rel_gpv_conv_rel_gpv''. *)
lemma inline1_parametric [transfer_rule]:
  "((S ===> C ===> rel_gpv (rel_prod (=) S) C') ===> rel_gpv A C ===> S
   ===> rel_spmf (rel_sum (rel_prod A S) (rel_prod C' (rel_prod (rel_rpv (rel_prod (=) S) C') (rel_rpv A C)))))
  inline1 inline1"
unfolding rel_gpv_conv_rel_gpv'' by(rule inline1_parametric')

(* Parametricity of inline with respect to rel_gpv''. inline and inline_aux
   are unfolded to their definitions, and the function arguments are
   introduced by hand with rel_funI before transfer_prover is called. *)
lemma inline_parametric':
  notes [transfer_rule] = inline1_parametric' the_gpv_parametric' corec_gpv_parametric'
  shows "((S ===> C ===> rel_gpv'' (rel_prod R S) C' R') ===> rel_gpv'' A C R ===> S ===> rel_gpv'' (rel_prod A S) C' R')
  inline inline"
unfolding inline_def[abs_def] inline_aux_def
(* apply transfer_prover raises loose bound variable *)
apply(rule rel_funI)+
subgoal premises [transfer_rule] by transfer_prover
done

(* Parametricity of inline stated with rel_gpv; registered as a transfer
   rule. Follows from inline_parametric' via rel_gpv_conv_rel_gpv''. *)
lemma inline_parametric [transfer_rule]:
  "((S ===> C ===> rel_gpv (rel_prod (=) S) C') ===> rel_gpv A C ===> S ===> rel_gpv (rel_prod A S) C')
  inline inline"
unfolding rel_gpv_conv_rel_gpv'' by(rule inline_parametric')
end


text ‹Associativity rule for @{const inline}›

context
  fixes callee1 :: "'s1  'c1  ('r1 × 's1, 'c, 'r) gpv"
  and callee2 :: "'s2  'c2  ('r2 × 's2, 'c1, 'r1) gpv"
begin

(* inline2 is the two-level analogue of inline1 for the callees callee1 and
   callee2 fixed in this context, with states s2 (for callee2) and s1 (for
   callee1):
   - if gpv finishes with x, the outcome is Inl (x, s2, s1);
   - if gpv issues a call out, inline1 callee1 (callee2 s2 out) s1 is run;
     an Inl outcome ((r2, s2), s1) resumes the search with rpv r2 and the new
     states, an Inr outcome (x, rpv'', rpv') stops with
     Inr (x, rpv'', rpv', rpv), i.e. the call plus all three continuations. *)
partial_function (spmf) inline2 :: "('a, 'c2, 'r2) gpv  's2  's1
   ('a × ('s2 × 's1) + 'c × ('r1 × 's1, 'c, 'r) rpv × ('r2 × 's2, 'c1, 'r1) rpv × ('a, 'c2, 'r2) rpv) spmf"
where
  "inline2 gpv s2 s1 =
  bind_spmf (the_gpv gpv)
   (case_generat (λx. return_spmf (Inl (x, s2, s1)))
     (λout rpv. bind_spmf (inline1 callee1 (callee2 s2 out) s1)
       (case_sum (λ((r2, s2), s1). inline2 (rpv r2) s2 s1)
         (λ(x, rpv'', rpv'). return_spmf (Inr (x, rpv'', rpv', rpv))))))"

(* Fixpoint induction for inline2 in curried form, with cases adm, bottom and
   step. The step case spells out the functional with explicit case
   expressions; it is derived from inline2.fixp_induct by unfolding
   curry_conv and split_def. *)
lemma inline2_fixp_induct [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λinline2. P (λgpv s2 s1. inline2 ((gpv, s2), s1)))"
  and "P (λ_ _ _. return_pmf None)"
  and "inline2'. P inline2' 
       P (λgpv s2 s1. bind_spmf (the_gpv gpv) (λgenerat. case generat of
           Pure x  return_spmf (Inl (x, s2, s1))
         | IO out rpv  bind_spmf (inline1 callee1 (callee2 s2 out) s1) (λlr. case lr of
              Inl ((r2, s2), c)  inline2' (rpv r2) s2 c
           | Inr (x, rpv'', rpv')  return_spmf (Inr (x, rpv'', rpv', rpv)))))"
  shows "P inline2"
using assms unfolding split_def by(rule inline2.fixp_induct[unfolded curry_conv[abs_def] split_def])

(* Expresses inline1 callee1 applied to (inline callee2 gpv s2) through
   inline2: the Inl outcomes only have their state tuple re-associated, and in
   the Inr outcomes the two inner continuations are merged into one
   continuation that binds rpv' into inline callee2 (rpv r2) s2.
   The equality is shown by antisymmetry of ord_spmf (=). The first direction
   uses inline1_fixp_induct_strong2 and carries a second, auxiliary statement
   about bind through the induction (see the inline remark below); the second
   direction is a plain fixpoint induction over inline2.
   inline1_1 is a local copy of inline1 at a fixed type, introduced so that
   selected occurrences can be left untouched by the induction. *)
lemma inline1_inline_conv_inline2:
  fixes gpv' :: "('r2 × 's2, 'c1, 'r1) gpv"
  shows "inline1 callee1 (inline callee2 gpv s2) s1 = 
  map_spmf (map_sum (λ(x, (s2, s1)). ((x, s2), s1))
    (λ(x, rpv'', rpv', rpv). (x, rpv'', λr1. rpv' r1  (λ(r2, s2). inline callee2 (rpv r2) s2))))
  (inline2 gpv s2 s1)"
  (is "?lhs = ?rhs")
proof(rule spmf.leq_antisym)
  define inline1_1 :: "('s1  'c1  ('r1 × 's1, 'c, 'r) gpv)  ('r2 × 's2, 'c1, 'r1) gpv  's1  _"
    where "inline1_1 = inline1"
  have "ord_spmf (=) ?lhs ?rhs"
    ― ‹We need in the inductive step that the approximation behaves well with @{const bind_gpv}
         because of @{thm [source] inline_aux_Inr}. So we have to thread it through the induction
         and do one half of the proof from @{thm [source] inline1_bind_gpv} again. We cannot inline
         @{thm [source] inline1_bind_gpv} in this proof here because the types are too specific.›
    and "ord_spmf (=) (inline1 callee1 (gpv'  f) s1') 
      (do {
      res  inline1_1 callee1 gpv' s1';
      case res of Inl (x, s')  inline1 callee1 (f x) s'
      | Inr (out, rpv', rpv)  return_spmf (Inr (out, rpv', rpv  f))
    })" for gpv' and f :: "_  ('a × 's2, 'c1, 'r1) gpv" and s1'
  proof(induction arbitrary: gpv s2 s1 gpv' f s1' rule: inline1_fixp_induct_strong2)
    case adm thus ?case
      apply(rule cont_intro)+
      subgoal for a b c d by(cases d; clarsimp)
      done

    case (step inline1')
    note step_IH = step.IH[unfolded inline1_1_def] and step_hyps = step.hyps[unfolded inline1_1_def]
    { case 1
      have inline1: "ord_spmf (=)
         (inline1 callee2 gpv s2  (λlr. case lr of Inl as2  return_spmf (Inl (as2, s1))
            | Inr (out1, rpv', rpv)  the_gpv (callee1 s1 out1)  (λgenerat. case generat of
                Pure (r1, s1)  inline1' (bind_gpv (rpv' r1) (λ(r2, s2). inline callee2 (rpv r2) s2)) s1
              | IO out rpv''  return_spmf (Inr (out, rpv'', λr1. bind_gpv (rpv' r1) (λ(r2, s2). inline callee2 (rpv r2) s2)) ))))
         (the_gpv gpv  (λgenerat. case generat of Pure x  return_spmf (Inl ((x, s2), s1))
            | IO out2 rpv  inline1_1 callee1 (callee2 s2 out2) s1  (λlr. case lr of
                Inl ((r2, s2), s1) 
                   map_spmf (map_sum (λ(x, s2, s1). ((x, s2), s1)) (λ(x, rpv'', rpv', rpv). (x, rpv'', λr1. bind_gpv (rpv' r1) (λ(r2, s2). inline callee2 (rpv r2) s2))))
                     (inline2 (rpv r2) s2 s1)
              | Inr (out, rpv'', rpv') 
                   return_spmf (Inr (out, rpv'', λr1. bind_gpv (rpv' r1) (λ(r2, s2). inline callee2 (rpv r2) s2))))))"
      proof(induction arbitrary: gpv s2 s1 rule: inline1_fixp_induct)
        case step2: (step inline1'')
        note step2_IH = step2.IH[unfolded inline1_1_def]

        show ?case unfolding inline1_1_def
          apply(rewrite in "ord_spmf _ _ " inline1.simps)
          apply(clarsimp intro!: ord_spmf_bind_reflI split: generat.split)
          apply(rule conjI)
          subgoal by(rewrite in "ord_spmf _ _ " inline2.simps)(clarsimp simp add: map_spmf_bind_spmf o_def split: generat.split sum.split intro!: ord_spmf_bind_reflI spmf.leq_trans[OF step2_IH])
          subgoal by(clarsimp intro!: ord_spmf_bind_reflI step_IH[THEN spmf.leq_trans] split: generat.split sum.split simp add: bind_rpv_def)
          done
      qed simp_all
      show ?case
        apply(rewrite in "ord_spmf _  _" inline_sel)
        apply(rewrite in "ord_spmf _ _ " inline2.simps)
        apply(clarsimp simp add: map_spmf_bind_spmf bind_map_spmf o_def intro!: ord_spmf_bind_reflI split: generat.split)
        apply(rule spmf.leq_trans[OF spmf.leq_trans, OF _ inline1])
        apply(auto intro!: ord_spmf_bind_reflI split: sum.split generat.split simp add: inline1_1_def map_spmf_bind_spmf)
        done }
    { case 2
      show ?case unfolding inline1_1_def
        by(rewrite inline1.simps)(auto simp del: bind_gpv_sel' simp add: bind_gpv.sel map_spmf_bind_spmf bind_map_spmf o_def bind_rpv_def intro!: ord_spmf_bind_reflI step_IH(2)[THEN spmf.leq_trans] step_hyps(2) split: generat.split sum.split) }
  qed simp_all
  thus "ord_spmf (=) ?lhs ?rhs" by -

  show "ord_spmf (=) ?rhs ?lhs"
  proof(induction arbitrary: gpv s2 s1 rule: inline2_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case (step inline2')
    show ?case
      apply(rewrite in "ord_spmf _ _ " inline1.simps)
      apply(rewrite inline_sel)
      apply(rewrite in "ord_spmf _  _" inline1.simps)
      apply(rewrite in "ord_spmf _ _ " inline1.simps)
      apply(clarsimp simp add: map_spmf_bind_spmf bind_map_spmf intro!: ord_spmf_bind_reflI split: generat.split)
      apply(rule conjI)
      subgoal
         apply clarsimp
         apply(rule step.IH[THEN spmf.leq_trans])
         apply(rewrite in "ord_spmf _  _" inline1.simps)
         apply(rewrite inline_sel)
         apply(simp add: bind_map_spmf)
         done
      subgoal by(clarsimp intro!: ord_spmf_bind_reflI split: generat.split sum.split simp add: o_def inline1_bind_gpv bind_rpv_def step.IH)
      done
  qed
qed

(* Companion to inline1_inline_conv_inline2 for the other bracketing: inline1
   with the composed callee (callee2 inlined by callee1, states paired as
   (s2, s1)) is inline2 with Inl outcomes unchanged and, in Inr outcomes, the
   two outer continuations merged into one.
   Shown by antisymmetry of ord_spmf (=), with a fixpoint induction over
   inline1 for one direction and over inline2 for the other. *)
lemma inline1_inline_conv_inline2':
  "inline1 (λ(s2, s1) c2. map_gpv (λ((r, s2), s1). (r, s2, s1)) id (inline callee1 (callee2 s2 c2) s1)) gpv (s2, s1) =
   map_spmf (map_sum id (λ(x, rpv'', rpv', rpv). (x, λr. bind_gpv (rpv'' r)
       (λ(r1, s1). map_gpv (λ((r2, s2), s1). (r2, s2, s1)) id (inline callee1 (rpv' r1) s1)), rpv)))
     (inline2 gpv s2 s1)"
  (is "?lhs = ?rhs")
proof(rule spmf.leq_antisym)
  show "ord_spmf (=) ?lhs ?rhs"
  proof(induction arbitrary: gpv s2 s1 rule: inline1_fixp_induct)
    case (step inline1') show ?case
      by(rewrite inline2.simps)(auto simp add: map_spmf_bind_spmf o_def inline_sel gpv.map_sel bind_map_spmf id_def[symmetric] gpv.map_id map_gpv_bind_gpv split_def intro!: ord_spmf_bind_reflI step.IH[THEN spmf.leq_trans] split: generat.split sum.split)
  qed simp_all
  show "ord_spmf (=) ?rhs ?lhs"
  proof(induction arbitrary: gpv s2 s1 rule: inline2_fixp_induct)
   case (step inline2')
   show ?case
     apply(rewrite in "ord_spmf _ _ " inline1.simps)
     apply(clarsimp simp add: map_spmf_bind_spmf bind_rpv_def o_def gpv.map_sel bind_map_spmf inline_sel map_gpv_bind_gpv id_def[symmetric] gpv.map_id split_def split: generat.split sum.split intro!: ord_spmf_bind_reflI)
     apply(rule spmf.leq_trans[OF spmf.leq_trans, OF _ step.IH])
     apply(auto simp add: split_def id_def[symmetric] intro!: ord_spmf_reflI)
     done
  qed simp_all
qed

(* Associativity of inline: inlining callee1 into (callee2 inlined into gpv)
   equals inlining the composed callee into gpv with the paired state
   (s2, s1); map_gpv only re-associates the result/state tuples.
   Proved by coinduction with gpv_coinduct_bind. Both sides are first rewritten
   to inline2 with inline1_inline_conv_inline2 and
   inline1_inline_conv_inline2'; the auxiliary fact stated first supplies the
   rel_gpv witness needed for the continuations. *)
lemma inline_assoc:
  "inline callee1 (inline callee2 gpv s2) s1 =
   map_gpv (λ(r, s2, s1). ((r, s2), s1)) id (inline (λ(s2, s1) c2. map_gpv (λ((r, s2), s1). (r, s2, s1)) id (inline callee1 (callee2 s2 c2) s1)) gpv (s2, s1))"
proof(coinduction arbitrary: s2 s1 gpv rule: gpv_coinduct_bind[where ?'b = "('r2 × 's2) × 's1" and ?'c = "('r2 × 's2) × 's1"])
  case (Eq_gpv s2 s1 gpv)
  have "gpv2 gpv2' (f :: ('r2 × 's2) × 's1  _) (f' :: ('r2 × 's2) × 's1  _).
          bind_gpv (bind_gpv (rpv'' r) (λ(r1, s1). inline callee1 (rpv' r1) s1)) (λ((r2, s2), s1). inline callee1 (inline callee2 (rpv r2) s2) s1) = gpv2  f 
          bind_gpv (bind_gpv (rpv'' r) (λ(r1, s1). inline callee1 (rpv' r1) s1)) (λ((r2, s2), s1). map_gpv (λ(r, s2, y). ((r, s2), y)) id (inline (λ(s2, s1) c2. map_gpv (λ((r, s2), s1). (r, s2, s1)) id (inline callee1 (callee2 s2 c2) s1)) (rpv r2) (s2, s1))) = gpv2'  f' 
          rel_gpv (λx y. s2 s1 gpv. f x = inline callee1 (inline callee2 gpv s2) s1 
              f' y = map_gpv (λ(r, s2, y). ((r, s2), y)) id (inline (λ(s2, s1) c2. map_gpv (λ((r, s2), s1). (r, s2, s1)) id (inline callee1 (callee2 s2 c2) s1)) gpv (s2, s1)))
            (=) gpv2 gpv2'" 
    for rpv'' :: "('r1 × 's1, 'c, 'r) rpv" and rpv' :: "('r2 × 's2, 'c1, 'r1) rpv" and rpv :: "('a, 'c2, 'r2) rpv" and r :: 'r
    by(auto intro!: exI gpv.rel_refl)
  then show ?case
    apply(subst inline_sel)
    apply(subst gpv.map_sel)
    apply(subst inline_sel)
    apply(subst inline1_inline_conv_inline2)
    apply(subst inline1_inline_conv_inline2')
    apply(unfold spmf.map_comp o_def case_sum_map_sum spmf_rel_map generat.rel_map)
    apply(rule rel_spmf_reflI)
    subgoal for lr by(cases lr)(auto del: disjCI intro!: rel_funI disjI2 simp add: split_def map_gpv_conv_bind[folded id_def] bind_gpv_assoc)
    done
qed

end

(* Analogue of set_inline1_lift_spmf1 for inline2: when the first callee is of
   the form lift_spmf (p s x), the support of inline2 is related to range Inl.
   Proved by fixpoint induction on inline2, using
   in_set_inline1_lift_spmf1 for the inner inline1. *)
lemma set_inline2_lift_spmf1: "set_spmf (inline2 (λs x. lift_spmf (p s x)) callee gpv s s')  range Inl"
apply(induction arbitrary: gpv s s' rule: inline2_fixp_induct)
subgoal by(rule cont_intro ccpo_class.admissible_leI)+
apply(auto simp add: o_def bind_UNION split: generat.split_asm sum.split_asm dest!: in_set_inline1_lift_spmf1)
apply blast
done

(* Elementwise form of set_inline2_lift_spmf1: an outcome y of inline2 with a
   lift_spmf first callee has the shape Inl (r, s, s'). *)
lemma in_set_inline2_lift_spmf1: "y  set_spmf (inline2 (λs x. lift_spmf (p s x)) callee gpv s s')  r s s'. y = Inl (r, s, s')"
by(drule set_inline2_lift_spmf1[THEN subsetD]) auto

context
  fixes consider' :: "'call  bool" 
  and "consider" :: "'call'  bool" 
  and callee :: "'s  'call  ('ret × 's, 'call', 'ret') gpv" 
  notes [[function_internals]]
begin

(* Private variant of inline1 for the interaction-bound proofs. It follows the
   same recursion but its Inr outcome also records the call out of gpv that
   triggered the callee: Inr (out, out', rpv', rpv), where out' is the call
   issued by the callee. *)
private partial_function (spmf) inline1'
  :: "('a, 'call, 'ret) gpv  's
   ('a × 's + 'call × 'call' × ('ret × 's, 'call', 'ret') rpv × ('a, 'call, 'ret) rpv) spmf"
where
  "inline1' gpv s =
   the_gpv gpv 
   case_generat (λx. return_spmf (Inl (x, s)))
     (λout rpv. the_gpv (callee s out) 
         case_generat (λ(x, y). inline1' (rpv x) y)
          (λout' rpv'. return_spmf (Inr (out, out', rpv', rpv))))"

(* Fixpoint induction for inline1' in curried form, with cases adm, bottom and
   step; derived from inline1'.fixp_induct by unfolding curry_conv. *)
private lemma inline1'_fixp_induct [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λinline1'. P (λgpv s. inline1' (gpv, s)))"
  and "P (λ_ _. return_pmf None)"
  and "inline1'. P inline1'  P (λgpv s. the_gpv gpv  case_generat (λx. return_spmf (Inl (x, s))) (λout rpv. the_gpv (callee s out)  case_generat (λ(x, y). inline1' (rpv x) y) (λout' rpv'. return_spmf (Inr (out, out', rpv', rpv)))))"
  shows "P inline1'"
using assms by(rule inline1'.fixp_induct[unfolded curry_conv[abs_def]])

(* inline1 is inline1' with the extra first component of each Inr outcome
   dropped (map_sum id snd). Proved by parallel fixpoint induction over both
   definitions at once. *)
private lemma inline1_conv_inline1': "inline1 callee gpv s = map_spmf (map_sum id snd) (inline1' gpv s)"
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf inline1.mono inline1'.mono inline1_def inline1'_def, unfolded lub_spmf_empty, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1 inline1')
  thus ?case by(clarsimp simp add: map_spmf_bind_spmf o_def intro!: bind_spmf_cong[OF refl] split: generat.split)
qed

context
  fixes q :: "enat" 
  assumes q: "s x. consider' x  interaction_bound consider (callee s x)  q"
  and ignore: "s x. ¬ consider' x  interaction_bound consider (callee s x) = 0"
begin

(* Set-level invariant behind interaction_bound_inline1'. Given a bound p on
   interaction_bound consider' gpv, every outcome of inline1' gpv s is either
   in range Inl or an Inr (out', out, c', rpv) whose continuations satisfy
   interaction-bound constraints in terms of p and q:
   - if consider' out' holds, the constraint on the callee continuation c'
     involves q, with an eSuc when consider out holds, and the constraint on
     the caller continuation rpv involves eSuc and p;
   - otherwise consider out is false, interaction_bound consider (c' input)
     is 0 and the constraint on rpv involves p without eSuc.
   Uses the context assumptions q and ignore about callee. Proved by fixpoint
   induction on inline1'; the facts marked with one and two stars handle the
   caller continuation and the first callee call respectively. *)
private lemma interaction_bound_inline1'_aux:
  "interaction_bound consider' gpv  p
   set_spmf (inline1' gpv s)  {Inr (out', out, c', rpv) | out' out c' rpv. 
        if consider' out'
        then (input. (if consider out then eSuc (interaction_bound consider (c' input)) else interaction_bound consider (c' input))  q)  
             (x. eSuc (interaction_bound consider' (rpv x))  p)
        else ¬ consider out  (input. interaction_bound consider (c' input) = 0)  (x. interaction_bound consider' (rpv x)  p)}
       range Inl"
proof(induction arbitrary: gpv s rule: inline1'_fixp_induct)
  { case adm show ?case by(rule cont_intro ccpo_class.admissible_leI)+ }
  { case bottom show ?case by simp }
  case (step inline1')
  have *: "interaction_bound consider' (c input)  p" if "IO out c  set_spmf (the_gpv gpv)" for out c input
    by(cases "consider' out")(auto intro: interaction_bound_IO_consider[OF that, THEN order_trans, THEN order_trans[OF ile_eSuc]] interaction_bound_IO_ignore[OF that, THEN order_trans] step.prems)
  have **: "if consider' out'
    then (input. (if consider out then eSuc (interaction_bound consider (c input))  else interaction_bound consider (c input))  q) 
         (x. eSuc (interaction_bound consider' (rpv x))  p)
    else ¬ consider out  (input. interaction_bound consider (c input) = 0)  (x. interaction_bound consider' (rpv x)  p)"
    if "IO out' rpv  set_spmf (the_gpv gpv)" "IO out c  set_spmf (the_gpv (callee s out'))"
    for out' rpv out c
  proof(cases "consider' out'")
    case True
    then show ?thesis using that q
      by(auto split del: if_split intro!: interaction_bound_IO[THEN order_trans] interaction_bound_IO_consider[THEN order_trans] step.prems)
  next
    case False
    have "¬ consider out" "interaction_bound consider (c input) = 0" for input
      using interaction_bound_IO[OF that(2), of "consider" input] ignore[OF False, of s]
      by(auto split: if_split_asm)
    then show ?thesis using False that
      by(auto split del: if_split intro: interaction_bound_IO_ignore[THEN order_trans] step.prems)
  qed
  show ?case
    by(auto 6 4 simp add: bind_UNION del: subsetI intro!: UN_least intro: step.IH * ** split: generat.split split del: if_split)
qed

(* Pointwise form of interaction_bound_inline1'_aux: for a concrete outcome
   Inr (out', out, c', rpv) of inline1' gpv s and a bound p for gpv, it gives
   the constraints on c' input and rpv x for a single input and a single x,
   with the same case distinction on consider' out' and consider out. *)
lemma interaction_bound_inline1':
  " Inr (out', out, c', rpv)  set_spmf (inline1' gpv s); interaction_bound consider' gpv  p 
   if consider' out' then
        (if consider out then eSuc (interaction_bound consider (c' input)) else interaction_bound consider (c' input))  q  
        eSuc (interaction_bound consider' (rpv x))  p
      else ¬ consider out  interaction_bound consider (c' input) = 0  interaction_bound consider' (rpv x)  p"
using interaction_bound_inline1'_aux[where gpv=gpv and p=p and s=s] by(auto split: if_split_asm)

end

(* Variant of interaction_bound_inline1' phrased with the predicate
   interaction_bounded_by instead of the function interaction_bound.
   Premises: an element Inr (out', out, c', rpv) of set_spmf (inline1' gpv s),
   a bound p for gpv w.r.t. consider', a bound q for callee on calls
   satisfying consider', and the bound 0 for callee on all other calls.
   The conclusion again splits on consider' out' and uses the decremented
   bounds q - 1 and p - 1 where the corresponding call is counted. *)
lemma interaction_bounded_by_inline1:
  " Inr (out', out, c', rpv)  set_spmf (inline1' gpv s); 
    interaction_bounded_by consider' gpv p;
    s x. consider' x  interaction_bounded_by consider (callee s x) q; 
    s x. ¬ consider' x  interaction_bounded_by consider (callee s x) 0 
   if consider' out' then
        (if consider out then q  0  interaction_bounded_by consider (c' input) (q - 1) else interaction_bounded_by consider (c' input) q) 
        p  0  interaction_bounded_by consider' (rpv x) (p - 1)
      else ¬ consider out  interaction_bounded_by consider (c' input) 0  interaction_bounded_by consider' (rpv x) p"
(* Rewrite the predicate into statements about interaction_bound. *)
unfolding interaction_bounded_by_0 unfolding interaction_bounded_by.simps
(* Feed the membership and bound premises into interaction_bound_inline1'. *)
apply(drule (1) interaction_bound_inline1'[where input=input and x=x, rotated 2], assumption, assumption)
(* Case analysis on p and q via co.enat.exhaust; the resulting goals are closed by simp. *)
apply(cases p q rule: co.enat.exhaust[case_product co.enat.exhaust])
apply(simp_all add: zero_enat_def[symmetric] eSuc_enat[symmetric] split: if_split_asm)
done

(* From this point on, enat_0_iff is a default simplification rule. *)
declare enat_0_iff [simp]

(* Interaction bound for inlining, registered in the interaction_bound
   collection.  If gpv is bounded by p w.r.t. consider', callee is bounded by
   q on calls satisfying consider' and by 0 on all other calls, then
   inline callee gpv s is bounded by p * q w.r.t. consider.
   The proof is a fixpoint induction (interaction_bound_fixp_induct) that
   establishes two statements at once: the product bound itself, and a bound
   for bind_gpv that is used as induction hypothesis inside the step case. *)
lemma interaction_bounded_by_inline [interaction_bound]:
  assumes p: "interaction_bounded_by consider' gpv p"
  and q: "s x. consider' x  interaction_bounded_by consider (callee s x) q"
  and ignore: "s x. ¬ consider' x  interaction_bounded_by consider (callee s x) 0"
  shows "interaction_bounded_by consider (inline callee gpv s) (p * q)"
proof
  have "interaction_bounded_by consider' gpv p  interaction_bound consider (inline callee gpv s)  p * q"
    and "interaction_bound consider (bind_gpv gpv' f)  interaction_bound consider gpv' + (SUP xresults'_gpv gpv'. interaction_bound consider (f x))"
    for gpv' and f :: "'ret × 's  ('a × 's, 'call', 'ret') gpv"
  proof(induction arbitrary: gpv s p gpv' f rule: interaction_bound_fixp_induct)
    case adm show ?case by simp
    case bottom case 1 show ?case by simp
    (* Step case, first statement: bound each generat in the supremum separately. *)
    case (step interaction_bound') case step: 1
    show ?case (is "(SUP generat?inline. ?lhs generat)  ?rhs")
    proof(rule SUP_least)
      fix generat
      assume "generat  ?inline"
      (* A generat of the inlined GPV stems either from an Inl or from an Inr result of inline1. *)
      then consider (Pure) ret s' where "generat = Pure (ret, s')"
          and "Inl (ret, s')  set_spmf (inline1 callee gpv s)"
        | (IO) out c rpv where "generat = IO out (λinput. bind_gpv (c input) (λ(ret, s'). inline callee (rpv ret) s'))"
          and "Inr (out, c, rpv)  set_spmf (inline1 callee gpv s)"
        by(clarsimp simp add: inline_sel split: sum.split_asm)
      then show "?lhs generat  ?rhs"
      proof(cases)
        case Pure thus ?thesis by simp
      next
        case IO
        (* Recover the outer call out' recorded by inline1'. *)
        from IO(2) obtain out' where out': "Inr (out', out, c, rpv)  set_spmf (inline1' gpv s)"
          by(auto simp add: inline1_conv_inline1' Inr_eq_map_sum_iff)
        show ?thesis
        proof(cases "consider' out'")
          (* The outer call out' satisfies consider'. *)
          case True
          with interaction_bounded_by_inline1[OF out' step.prems q ignore]
          have p: "p  0" and rpv: "x. interaction_bounded_by consider' (rpv x) (p - 1)"
            and c: "input. if consider out then q  0  interaction_bounded_by consider (c input) (q - 1) else interaction_bounded_by consider (c input) q"
            by auto

          have "?lhs generat  (if consider out then 1 else 0) + (SUP input. interaction_bound' (bind_gpv (c input) (λ(ret, s'). inline callee (rpv ret) s')))"
            (is "_  _ + ?sup")
            using IO(1) by(auto simp add: plus_1_eSuc)
          (* Split the bind using the second induction hypothesis. *)
          also have "?sup  (SUP input. interaction_bound consider (c input) + (SUP (ret, s')  results'_gpv (c input). interaction_bound' (inline callee (rpv ret) s')))"
            unfolding split_def by(rule SUP_mono)(blast intro: step.IH)
          (* Bound the inlined continuation by (p - 1) * q using the first induction hypothesis. *)
          also have "  (SUP input. interaction_bound consider (c input) + (SUP (ret, s')  results'_gpv (c input). (p - 1) * q))"
            using rpv by(auto intro!: SUP_mono rev_bexI add_mono step.IH)
          also have "  (SUP input. interaction_bound consider (c input) + (p - 1) * q)"
            apply(auto simp add: SUP_constant bot_enat_def intro!: SUP_mono)
            apply(metis add.right_neutral add_mono i0_lb order_refl)+
            done
          also have "  (SUP input :: 'ret'. (if consider out then q - 1 else q) + (p - 1) * q)"
            apply(rule SUP_mono rev_bexI UNIV_I add_mono)+
            using c
            apply(auto simp add: interaction_bounded_by.simps)
            done
          also have " = (if consider out then q - 1 else q) + (p - 1) * q"
            by(simp add: SUP_constant)
          (* Combine the chain with arithmetic on p and q after case analysis on both. *)
          finally show ?thesis
            apply(rule order_trans)
            prefer 5
            using p c
            apply(cases p; cases q)
            apply(auto simp add: one_enat_def algebra_simps Suc_leI)
            done
        next
          (* The outer call out' does not satisfy consider'. *)
          case False
          with interaction_bounded_by_inline1[OF out' step.prems q ignore]
          have out: "¬ consider out" and zero: "input. interaction_bounded_by consider (c input) 0"
            and rpv: "x. interaction_bounded_by consider' (rpv x) p" by auto
          have "?lhs generat  (SUP input. interaction_bound' (bind_gpv (c input) (λ(ret, s'). inline callee (rpv ret) s')))"
            using IO(1) out by auto
          also have "  (SUP input. interaction_bound consider (c input) + (SUP (ret, s')  results'_gpv (c input). interaction_bound' (inline callee (rpv ret) s')))"
            unfolding split_def by(rule SUP_mono)(blast intro: step.IH)
          also have "  (SUP input. (SUP (ret, s')  results'_gpv (c input). p * q))"
            using rpv zero by(auto intro!: SUP_mono rev_bexI add_mono step.IH simp add: interaction_bounded_by_0)
          also have "  (SUP input :: 'ret'. p * q)"
            by(rule SUP_mono rev_bexI)+(auto simp add: SUP_constant)
          also have " = p * q" by(simp add: SUP_constant)
          finally show ?thesis .
        qed
      qed
    qed
  next
    (* Second statement (bound for bind_gpv): bottom and step cases. *)
    case bottom case 2 show ?case by simp
    case step case 2 show ?case using step by -(rule interaction_bound_bind_step)
  qed
  then show "interaction_bound consider (inline callee gpv s)  p * q" using p by -
qed

end

(* Version of interaction_bounded_by_inline in which the bounds on callee
   (assumptions q and ignore) are only required for states satisfying I.
   In exchange, the initial state s must satisfy I (assumption I), and I must
   carry over to the states in results'_gpv (callee s x) (assumption invariant).
   Proof idea: assume a type 's' with a type_definition whose carrier is
   {s. I s}, move callee and s to that type with Rep/Abs, apply
   interaction_bounded_by_inline there, transfer the result back, and finally
   remove the type_definition assumption with cancel_type_definition. *)
lemma interaction_bounded_by_inline_invariant: (* TODO: augment with types *)
  includes lifting_syntax
  fixes consider' :: "'call  bool" 
  and "consider" :: "'call'  bool" 
  and callee :: "'s  'call  ('ret × 's, 'call', 'ret') gpv" 
  and gpv :: "('a, 'call, 'ret) gpv"
  assumes p: "interaction_bounded_by consider' gpv p"
  and q: "s x.  I s; consider' x   interaction_bounded_by consider (callee s x) q"
  and ignore: "s x.  I s; ¬ consider' x   interaction_bounded_by consider (callee s x) 0"
  and I: "I s"
  and invariant: "s x y s'.  (y, s')  results'_gpv (callee s x); I s   I s'"
  shows "interaction_bounded_by consider (inline callee gpv s) (p * q)"
proof -
  { assume "(Rep :: 's'  's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s'  's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    (* Correspondence relation between the original states and the new type. *)
    define cr where "cr x y  x = Rep y" for x y
    have [transfer_rule]: "bi_unique cr" "right_total cr"
      using td cr_def[abs_def] by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I"
      using type_definition_Domainp[OF td cr_def[abs_def]] by simp

    (* callee' is callee with its state argument and result state mapped through Rep and Abs. *)
    define callee' where "callee' = (Rep --->  id ---> map_gpv (map_prod id Abs) id) callee"
    have [transfer_rule]: "(cr ===> (=) ===> rel_gpv (rel_prod (=) cr) (=)) callee callee'"
      by(auto simp add: callee'_def rel_fun_def cr_def gpv.rel_map prod.rel_map td.Abs_inverse intro!: gpv.rel_refl_strong intro: td.Rep[simplified] dest: invariant)

    define s' where "s' = Abs s"
    have [transfer_rule]: "cr s s'" using I by(simp add: cr_def s'_def td.Abs_inverse)

    (* Collect the three premises of interaction_bounded_by_inline for callee'. *)
    note p moreover
    have "consider' x  interaction_bounded_by consider (callee' s x) q" for s x
      by(transfer fixing: "consider" consider' q)(clarsimp simp add: q)
    moreover have "¬ consider' x  interaction_bounded_by consider (callee' s x) 0" for s x
      by(transfer fixing: "consider" consider')(clarsimp simp add: ignore)
    ultimately have "interaction_bounded_by consider (inline callee' gpv s') (p * q)" 
      by(rule interaction_bounded_by_inline)
    then have "interaction_bounded_by consider (inline callee gpv s) (p * q)" by transfer  }
  from this[cancel_type_definition] I show ?thesis by blast
qed

(* Context for well-typedness and losslessness of inline.
   It fixes two interfaces (for the calls of the inlined GPV and for the calls
   made by callee) and the callee itself.  Assumption results relates
   results_gpv of callee s x, for calls x of the first interface, to the
   product of the responses to x with UNIV, i.e. it constrains only the
   returned value and not the returned state. *)
context
  fixes  :: "('call, 'ret) ℐ"
  and ℐ' :: "('call', 'ret') ℐ"
  and callee :: "'s  'call  ('ret × 's, 'call', 'ret') gpv"
  assumes results: "s x. x  outs_ℐ   results_gpv ℐ' (callee s x)  responses_ℐ  x × UNIV"
begin

(* If inline1 callee gpv s can return Inr (out, callee', rpv') and gpv is
   well-typed (assumption WT), then there are a call and a state s such that
   callee' x belongs to sub_gpvs of callee s call for the responses x to out.
   The proof generalises the claim to a statement about the whole support of
   inline1 callee gpv s and shows it by fixpoint induction
   (inline1_fixp_induct). *)
lemma inline1_in_sub_gpvs_callee:
  assumes "Inr (out, callee', rpv')  set_spmf (inline1 callee gpv s)"
  and WT: " ⊢g gpv "
  shows "callouts_ℐ . s. x  responses_ℐ ℐ' out. callee' x  sub_gpvs ℐ' (callee s call)"
proof -
  from WT
  have "set_spmf (inline1 callee gpv s)  {Inr (out, callee', rpv') | out callee' rpv'.
    callouts_ℐ . s. x  responses_ℐ ℐ' out. callee' x  sub_gpvs ℐ' (callee s call)}  range Inl"
    (is "?concl (inline1 callee) gpv s")
  proof(induction arbitrary: gpv s rule: inline1_fixp_induct)
    case adm show ?case by(intro cont_intro ccpo_class.admissible_leI)
    case bottom show ?case by simp
    case (step inline1')
    { fix out c
      assume IO: "IO out c  set_spmf (the_gpv gpv)" 
      from step.prems IO have out: "out  outs_ℐ " by(rule WT_gpvD)
      (* callee answers immediately: continue with c x and use the induction hypothesis. *)
      { fix x s'
        assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
        then have "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
        with out have "x  responses_ℐ  out" by(auto dest: results)
        with step.prems IO have " ⊢g c x " by(rule WT_gpvD)
        hence "?concl inline1' (c x) s'" by(rule step.IH)
      } moreover {
        (* callee issues a call out': its continuation c' is a direct sub-GPV (sub_gpvs.base). *)
        fix out' c'
        assume "IO out' c'  set_spmf (the_gpv (callee s out))"
        hence "xresponses_ℐ ℐ' out'. c' x  sub_gpvs ℐ' (callee s out)"
          by(auto intro: sub_gpvs.base)
        then have "callouts_ℐ . s. xresponses_ℐ ℐ' out'. c' x  sub_gpvs ℐ' (callee s call)"
          using out by blast
      } moreover note calculation }
    then show ?case using step.prems
      by(auto del: subsetI simp add: bind_UNION intro!: UN_least split: generat.split)
  qed
  thus ?thesis using assms by fastforce
qed

(* If inline1 callee gpv s can return Inr (out, callee', rpv'), input is a
   response to out, (x, s') is a result of callee' input, and gpv is
   well-typed, then the continuation rpv' x belongs to sub_gpvs of gpv.
   As for inline1_in_sub_gpvs_callee, the proof shows a statement about the
   whole support of inline1 callee gpv s by inline1_fixp_induct. *)
lemma inline1_in_sub_gpvs:
  assumes "Inr (out, callee', rpv')  set_spmf (inline1 callee gpv s)"
  and "(x, s')  results_gpv ℐ' (callee' input)"
  and "input  responses_ℐ ℐ' out"
  and " ⊢g gpv "
  shows "rpv' x  sub_gpvs  gpv"
proof -
  from  ⊢g gpv 
  have "set_spmf (inline1 callee gpv s)  {Inr (out, callee', rpv') | out callee' rpv'.
    input  responses_ℐ ℐ' out. (x, s')results_gpv ℐ' (callee' input). rpv' x  sub_gpvs  gpv}
     range Inl" (is "?concl (inline1 callee) gpv s" is "_  ?rhs gpv s")
  proof(induction arbitrary: gpv s rule: inline1_fixp_induct)
    case adm show ?case by(intro cont_intro ccpo_class.admissible_leI)
    case bottom show ?case by simp
  next
    case (step inline1')
    { fix out c
      assume IO: "IO out c  set_spmf (the_gpv gpv)" 
      from step.prems IO have out: "out  outs_ℐ " by(rule WT_gpvD)
      (* callee answers immediately: the induction hypothesis speaks about c x and is lifted to gpv via sub_gpvs.cont. *)
      { fix x s'
        assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
        then have "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
        with out have "x  responses_ℐ  out" by(auto dest: results)
        with step.prems IO have " ⊢g c x " by(rule WT_gpvD)
        hence "?concl inline1' (c x) s'" by(rule step.IH)
        also have "  ?rhs gpv s'" using IO Pure
          by(fastforce intro: sub_gpvs.cont dest: WT_gpv_OutD[OF step.prems] results[THEN subsetD, OF _ results_gpv.Pure])
        finally have "set_spmf (inline1' (c x) s')  " .
      } moreover {
        (* callee issues a call: c x is a direct sub-GPV of gpv (sub_gpvs.base). *)
        fix out' c' input x s'
        assume "IO out' c'  set_spmf (the_gpv (callee s out))"
          and "input  responses_ℐ ℐ' out'" and "(x, s')  results_gpv ℐ' (c' input)"
        then have "c x  sub_gpvs  gpv" using IO
          by(auto intro!: sub_gpvs.base dest: WT_gpv_OutD[OF step.prems] results[THEN subsetD, OF _ results_gpv.IO])
      } moreover note calculation }
    then show ?case
      by(auto simp add: bind_UNION intro!: UN_least split: generat.split del: subsetI)
  qed
  with assms show ?thesis by fastforce
qed

(* Additional assumption WT for the lemmas below: callee s x is well-typed
   with respect to the second interface for calls x of the first interface. *)
context
  assumes WT: "x s. x  outs_ℐ   ℐ' ⊢g callee s x "
begin

(* Well-typedness facts about a result Inr (out, rpv, rpv') of
   inline1 callee gpv s for a well-typed gpv:
   (1) out is a call of the second interface;
   (2) rpv input is well-typed w.r.t. the second interface for responses
       input to out;
   (3) rpv' x is well-typed for results (x, s') of rpv input.
   (1) is shown by inline1_fixp_induct; (2) and (3) follow from
   inline1_in_sub_gpvs_callee and inline1_in_sub_gpvs together with
   WT_sub_gpvsD. *)
lemma WT_gpv_inline1:
  assumes "Inr (out, rpv, rpv')  set_spmf (inline1 callee gpv s)"
  and " ⊢g gpv "
  shows "out  outs_ℐ ℐ'" (is "?thesis1")
  and "input  responses_ℐ ℐ' out  ℐ' ⊢g rpv input " (is "PROP ?thesis2")
  and " input  responses_ℐ ℐ' out; (x, s')  results_gpv ℐ' (rpv input)    ⊢g rpv' x " (is "PROP ?thesis3")
proof -
  from  ⊢g gpv 
  have "set_spmf (inline1 callee gpv s)  {Inr (out, rpv, rpv') | out rpv rpv'. out  outs_ℐ ℐ'}  range Inl"
  proof(induction arbitrary: gpv s rule: inline1_fixp_induct)
    { case adm show ?case by(intro cont_intro ccpo_class.admissible_leI) }
    { case bottom show ?case by simp }
    case (step inline1')
    { fix out c
      assume IO: "IO out c  set_spmf (the_gpv gpv)" 
      from step.prems IO have out: "out  outs_ℐ " by(rule WT_gpvD)
      { fix x s'
        assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
        then have "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
        with out have "x  responses_ℐ  out" by(auto dest: results)
        with step.prems IO have " ⊢g c x " by(rule WT_gpvD)
      } moreover {
        fix out' c'
        from out have "ℐ' ⊢g callee s out " by(rule WT)
        moreover assume "IO out' c'  set_spmf (the_gpv (callee s out))"
        ultimately have "out'  outs_ℐ ℐ'" by(rule WT_gpvD) 
      } moreover note calculation }
    then show ?case 
      by(auto del: subsetI simp add: bind_UNION intro!: UN_least split: generat.split intro!: step.IH[THEN order_trans])
  qed
  then show ?thesis1 using assms by auto

  (* Second statement. *)
  assume "input  responses_ℐ ℐ' out"
  with inline1_in_sub_gpvs_callee[OF ‹Inr _  _]  ⊢g gpv 
  obtain out' s where "out'  outs_ℐ " 
    and *: "rpv input  sub_gpvs ℐ' (callee s out')" by auto
  from out'  _ have "ℐ' ⊢g callee s out' " by(rule WT)
  then show "ℐ' ⊢g rpv input " using * by(rule WT_sub_gpvsD)

  (* Third statement. *)
  assume "(x, s')  results_gpv ℐ' (rpv input)"
  with ‹Inr _  _ have "rpv' x  sub_gpvs  gpv"
    using input  _  ⊢g gpv  by(rule inline1_in_sub_gpvs)
  with  ⊢g gpv  show " ⊢g rpv' x " by(rule WT_sub_gpvsD)
qed

(* Inlining preserves well-typedness: for a well-typed gpv,
   inline callee gpv s is well-typed with respect to the second interface.
   Proved by coinduction with the rule WT_gpv_coinduct_bind; the facts about
   one step of inlining come from WT_gpv_inline1. *)
lemma WT_gpv_inline:
  assumes " ⊢g gpv "
  shows "ℐ' ⊢g inline callee gpv s "
using assms
proof(coinduction arbitrary: gpv s rule: WT_gpv_coinduct_bind)
  case (WT_gpv out c gpv)
  (* An IO step of the inlined GPV stems from an Inr result of inline1. *)
  from ‹IO out c  _ obtain callee' rpv'
    where Inr: "Inr (out, callee', rpv')  set_spmf (inline1 callee gpv s)"
    and c: "c = (λinput. callee' input  (λ(x, s). inline callee (rpv' x) s))"
    by(clarsimp simp add: inline_sel split: sum.split_asm)
  from Inr  ⊢g gpv  have ?out by(rule WT_gpv_inline1)
  moreover have "?cont TYPE('ret × 's)" (is "input_. _  _  ?case' input")
  proof(rule ballI disjI2)+
    fix input
    assume "input  responses_ℐ ℐ' out"
    with Inr  ⊢g gpv have "ℐ' ⊢g callee' input "
      and "x s'. (x, s')  results_gpv ℐ' (callee' input)   ⊢g rpv' x "
      by(blast intro: WT_gpv_inline1)+
    then show "?case' input" by(subst c)(auto 4 4)
  qed
  ultimately show "?case TYPE('ret × 's)" ..
qed

end

(* Context for the losslessness lemmas about inline1: fixes a GPV gpv that is
   assumed lossless (gpv(1)) and well-typed (gpv(2)). *)
context
  fixes gpv :: "('a, 'call, 'ret) gpv"
  assumes gpv: "lossless_gpv  gpv" " ⊢g gpv "
begin

(* If the first step the_gpv (callee s x) of the callee is a lossless spmf for
   the calls x covered by the assumption, then inline1 callee gpv s is a
   lossless spmf.  Proved by induction with lossless_WT_gpv_induct after
   unfolding inline1 once. *)
lemma lossless_spmf_inline1:
  assumes lossless: "s x. x  outs_ℐ   lossless_spmf (the_gpv (callee s x))"
  shows "lossless_spmf (inline1 callee gpv s)"
using gpv
proof(induction arbitrary: s rule: lossless_WT_gpv_induct)
  case (lossless_gpv p)
  show ?case using ‹lossless_spmf p
    apply(subst inline1_unfold)
    apply(auto split: generat.split intro: lossless lossless_gpv.hyps dest: results[THEN subsetD, rotated, OF results_gpv.Pure] intro: lossless_gpv.IH)
    done
qed

(* For a result Inr (out, rpv, rpv') of inline1 callee gpv s and a response
   input to out, the remaining callee computation rpv input is lossless,
   provided callee s x is lossless for the calls x covered by the assumption.
   rpv input is located inside some callee s out' via
   inline1_in_sub_gpvs_callee; losslessness then follows by
   lossless_sub_gpvsD. *)
lemma lossless_gpv_inline1:
  assumes *: "Inr (out, rpv, rpv')  set_spmf (inline1 callee gpv s)"
  and **: "input  responses_ℐ ℐ' out"
  and lossless: "s x. x  outs_ℐ   lossless_gpv ℐ' (callee s x)"
  shows "lossless_gpv ℐ' (rpv input)"
proof -
  from inline1_in_sub_gpvs_callee[OF * gpv(2)] **
  obtain out' s where "out'  outs_ℐ " and ***: "rpv input  sub_gpvs ℐ' (callee s out')" by blast
  from out'  _ have "lossless_gpv ℐ' (callee s out')" by(rule lossless)
  thus ?thesis using *** by(rule lossless_sub_gpvsD)
qed

(* For a result Inr (out, rpv, rpv') of inline1 callee gpv s, the
   continuation rpv' x of the outer GPV is lossless for every result (x, s')
   of rpv input, where input is a response to out.  rpv' x is a sub-GPV of
   gpv by inline1_in_sub_gpvs, and gpv is lossless by the context assumption. *)
lemma lossless_results_inline1:
  assumes "Inr (out, rpv, rpv')  set_spmf (inline1 callee gpv s)"
  and "(x, s')  results_gpv ℐ' (rpv input)"
  and "input  responses_ℐ ℐ' out"
  shows "lossless_gpv  (rpv' x)"
proof -
  from assms gpv(2) have "rpv' x  sub_gpvs  gpv" by(rule inline1_in_sub_gpvs)
  with gpv(1) show "lossless_gpv  (rpv' x)" by(rule lossless_sub_gpvsD)
qed

end

(* Collection of the three losslessness lemmas about inline1, each with its
   premises rotated by two positions. *)
lemmas lossless_inline1[rotated 2] = lossless_spmf_inline1 lossless_gpv_inline1 lossless_results_inline1

(* Inlining preserves losslessness: if gpv is lossless and well-typed and
   callee s x is lossless for the calls x covered by the assumption, then
   inline callee gpv s is lossless w.r.t. the second interface.
   Proved by induction with lossless_WT_gpv_induct_strong; the final step
   applies lossless_gpvI to (a) losslessness of the first step of the inlined
   GPV and (b) losslessness of each continuation c input. *)
lemma lossless_inline[rotated]:
  fixes gpv :: "('a, 'call, 'ret) gpv"
  assumes gpv: "lossless_gpv  gpv" " ⊢g gpv "
  and lossless: "s x. x  outs_ℐ   lossless_gpv ℐ' (callee s x)"
  shows "lossless_gpv ℐ' (inline callee gpv s)"
using gpv
proof(induction arbitrary: s rule: lossless_WT_gpv_induct_strong)
  case (lossless_gpv p)
  have lp: "lossless_gpv  (GPV p)" by(rule lossless_sub_gpvsI)(auto intro: lossless_gpv.hyps)
  moreover have wp: " ⊢g GPV p " by(rule WT_sub_gpvsI)(auto intro: lossless_gpv.hyps)
  (* (a) the first step of the inlined GPV is a lossless spmf. *)
  ultimately have "lossless_spmf (the_gpv (inline callee (GPV p) s))"
    by(auto simp add: inline_sel intro: lossless_spmf_inline1 lossless_gpv_lossless_spmfD[OF lossless])
  moreover {
    (* (b) every continuation after an IO step is lossless. *)
    fix out c input
    assume IO: "IO out c  set_spmf (the_gpv (inline callee (GPV p) s))"
      and "input  responses_ℐ ℐ' out"
    from IO obtain callee' rpv
      where Inr: "Inr (out, callee', rpv)  set_spmf (inline1 callee (GPV p) s)"
      and c: "c = (λinput. callee' input  (λ(x, y). inline callee (rpv x) y))"
      by(clarsimp simp add: inline_sel split: sum.split_asm)
    from Inr input  _ lossless lp wp have "lossless_gpv ℐ' (callee' input)" by(rule lossless_inline1)
    moreover {
      fix x s'
      assume "(x, s')  results_gpv ℐ' (callee' input)"
      with Inr have "rpv x  sub_gpvs  (GPV p)" using input  _ wp by(rule inline1_in_sub_gpvs)
      hence "lossless_gpv ℐ' (inline callee (rpv x) s')" by(rule lossless_gpv.IH)
    } ultimately have "lossless_gpv ℐ' (c input)" unfolding c by clarsimp
  } ultimately show ?case by(rule lossless_gpvI)
qed

end

(* The identity oracle: on call x in state s it issues the same call x via
   Pause and, on receiving an answer, returns that answer paired with the
   unchanged state s. *)
definition id_oracle :: "'s  'call  ('ret × 's, 'call, 'ret) gpv"
where "id_oracle s x = Pause x (λx. Done (x, s))"

(* Closed form of inline1 for id_oracle: map over the first step of gpv,
   sending Pure x to Inl (x, s) and IO out c to Inr with the trivial callee
   continuation (λx. Done (x, s)) and the original continuation c. *)
lemma inline1_id_oracle:
  "inline1 id_oracle gpv s =
   map_spmf (λgenerat. case generat of Pure x  Inl (x, s) | IO out c  Inr (out, λx. Done (x, s), c)) (the_gpv gpv)"
by(subst inline1.simps)(auto simp add: id_oracle_def map_spmf_conv_bind_spmf intro!: bind_spmf_cong split: generat.split)

(* Inlining id_oracle only pairs every result of gpv with the state s; calls
   are mapped with id.  Declared as a simp rule; proved by coinduction. *)
lemma inline_id_oracle [simp]: "inline id_oracle gpv s = map_gpv (λx. (x, s)) id gpv"
by(coinduction arbitrary: gpv s)(auto 4 3 simp add: inline_sel inline1_id_oracle spmf_rel_map gpv.map_sel o_def generat.rel_map intro!: rel_spmf_reflI rel_funI split: generat.split)

(* Locale for a callee with a state invariant I.  It fixes two interfaces, the
   callee and the predicate I on callee states, and assumes, for calls x of
   the first interface and states s with I s:
   - results_callee: relates results_gpv of callee s x to the product of the
     responses to x with the set {s. I s} of states satisfying I;
   - WT_callee: callee s x is well-typed w.r.t. the second interface. *)
locale raw_converter_invariant =
  fixes  :: "('call, 'ret) ℐ"
    and ℐ' :: "('call', 'ret') ℐ"
    and callee :: "'s  'call  ('ret × 's, 'call', 'ret') gpv"
    and I :: "'s  bool"
  assumes results_callee: "s x.  x  outs_ℐ ; I s   results_gpv ℐ' (callee s x)  responses_ℐ  x × {s. I s}"
    and WT_callee: "x s.  x  outs_ℐ ; I s   ℐ' ⊢g callee s x "
begin

context begin
(* Private auxiliary lemma describing the support of inline1 callee gpv s for
   a well-typed gpv and a state s with I s.  The right-hand side consists of
   - Inr (out, callee', rpv') such that callee' x is a sub-GPV of callee s call
     for some call and some state s satisfying I (?rhs1), and
   - Inl (x, s') such that x is a result of gpv and I s' holds (?rhs2).
   Proved by inline1_fixp_induct; the two public lemmas after it are
   projections of this statement. *)
private lemma aux:
  "set_spmf (inline1 callee gpv s)  {Inr (out, callee', rpv') | out callee' rpv'.
    callouts_ℐ . s. I s  (x  responses_ℐ ℐ' out. callee' x  sub_gpvs ℐ' (callee s call))} 
     {Inl (x, s') | x s'. x  results_gpv  gpv  I s'}"
  (is "?concl (inline1 callee) gpv s" is "_  ?rhs1  ?rhs2 gpv")
  if " ⊢g gpv " "I s"
  using that
proof(induction arbitrary: gpv s rule: inline1_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1')
  { fix out c
    assume IO: "IO out c  set_spmf (the_gpv gpv)" 
    from step.prems(1) IO have out: "out  outs_ℐ " by(rule WT_gpvD)
    (* callee answers immediately with (x, s'): the invariant holds for s' by results_callee. *)
    { fix x s'
      assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
      then have "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
      with out step.prems(2) have "x  responses_ℐ  out" "I s'" by(auto dest: results_callee)
      from step.prems(1) IO this(1) have " ⊢g c x " by(rule WT_gpvD)
      hence "?concl inline1' (c x) s'" using I s' by(rule step.IH)
      also have "  ?rhs1  ?rhs2 gpv" using x  _ IO by(auto intro: results_gpv.intros)
      also note calculation
    } moreover {
      (* callee issues a call out'. *)
      fix out' c'
      assume "IO out' c'  set_spmf (the_gpv (callee s out))"
      hence "xresponses_ℐ ℐ' out'. c' x  sub_gpvs ℐ' (callee s out)"
        by(auto intro: sub_gpvs.base)
      then have "callouts_ℐ . s. I s  (xresponses_ℐ ℐ' out'. c' x  sub_gpvs ℐ' (callee s call))"
        using out step.prems(2) by blast
    } moreover note calculation }
    then show ?case using step.prems
      by(auto 4 3 del: subsetI simp add: bind_UNION intro!: UN_least split: generat.split intro: results_gpv.intros)
  qed

(* Inr part of aux: for a result Inr (out, callee', rpv') of
   inline1 callee gpv s, with gpv well-typed and I s, there are a call and a
   state s satisfying I such that callee' x is a sub-GPV of callee s call for
   the responses x to out. *)
lemma inline1_in_sub_gpvs_callee:
  assumes "Inr (out, callee', rpv')  set_spmf (inline1 callee gpv s)"
    and WT: " ⊢g gpv "
    and s: "I s"
  shows "callouts_ℐ . s. I s  (x  responses_ℐ ℐ' out. callee' x  sub_gpvs ℐ' (callee s call))"
  using aux[OF WT s] assms(1) by fastforce

(* Inl part of aux: if inline1 callee gpv s can return Inl (x, s'), with gpv
   well-typed and I s, then x is a result of gpv and the final state s'
   satisfies I. *)
lemma inline1_Inl_results_gpv:
  assumes "Inl (x, s')  set_spmf (inline1 callee gpv s)"
    and WT: " ⊢g gpv "
    and s: "I s"
  shows "x  results_gpv  gpv  I s'"
  using aux[OF WT s] assms(1) by fastforce
end

(* Invariant-aware version of inline1_in_sub_gpvs: for a result
   Inr (out, callee', rpv') of inline1 callee gpv s, a response input to out
   and a result (x, s') of callee' input, with gpv well-typed and I s, the
   continuation rpv' x is a sub-GPV of gpv and the state s' satisfies I.
   Proved by generalising to the whole support of inline1 callee gpv s and
   using inline1_fixp_induct. *)
lemma inline1_in_sub_gpvs:
  assumes "Inr (out, callee', rpv')  set_spmf (inline1 callee gpv s)"
    and "(x, s')  results_gpv ℐ' (callee' input)"
    and "input  responses_ℐ ℐ' out"
    and " ⊢g gpv "
    and "I s"
  shows "rpv' x  sub_gpvs  gpv  I s'"
proof -
  from  ⊢g gpv  I s
  have "set_spmf (inline1 callee gpv s)  {Inr (out, callee', rpv') | out callee' rpv'.
    input  responses_ℐ ℐ' out. (x, s')results_gpv ℐ' (callee' input). I s'  rpv' x  sub_gpvs  gpv}
     {Inl (x, s') | x s'. I s'}" (is "?concl (inline1 callee) gpv s" is "_  ?rhs gpv s")
  proof(induction arbitrary: gpv s rule: inline1_fixp_induct)
    case adm show ?case by(intro cont_intro ccpo_class.admissible_leI)
    case bottom show ?case by simp
    case (step inline1')
    { fix out c
      assume IO: "IO out c  set_spmf (the_gpv gpv)" 
      from step.prems(1) IO have out: "out  outs_ℐ " by(rule WT_gpvD)
      (* callee answers immediately: use the induction hypothesis for c x and lift it to gpv. *)
      { fix x s'
        assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
        then have "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
        with out step.prems(2) have "x  responses_ℐ  out" "I s'" by(auto dest: results_callee)
        from step.prems(1) IO this(1) have " ⊢g c x " by(rule WT_gpvD)
        hence "?concl inline1' (c x) s'" using I s' by(rule step.IH)
        also have "  ?rhs gpv s'" using IO Pure I s
          by(fastforce intro: sub_gpvs.cont dest: WT_gpv_OutD[OF step.prems(1)] results_callee[THEN subsetD, OF _ _ results_gpv.Pure])
        finally have "set_spmf (inline1' (c x) s')  " .
      } moreover {
        (* callee issues a call: c x is a direct sub-GPV of gpv and s' satisfies I by results_callee. *)
        fix out' c' input x s'
        assume "IO out' c'  set_spmf (the_gpv (callee s out))"
          and "input  responses_ℐ ℐ' out'" and "(x, s')  results_gpv ℐ' (c' input)"
        then have "c x  sub_gpvs  gpv" "I s'" using IO I s
          by(auto intro!: sub_gpvs.base dest: WT_gpv_OutD[OF step.prems(1)] results_callee[THEN subsetD, OF _ _ results_gpv.IO])
      } moreover note calculation }
      then show ?case using step.prems(2)
        by(auto simp add: bind_UNION intro!: UN_least split: generat.split del: subsetI)
    qed
    with assms show ?thesis by fastforce
  qed

(* Invariant-aware version of WT_gpv_inline1.  For a result
   Inr (out, rpv, rpv') of inline1 callee gpv s, with gpv well-typed and I s:
   (1) out is a call of the second interface;
   (2) rpv input is well-typed w.r.t. the second interface for responses
       input to out;
   (3) for results (x, s') of rpv input, rpv' x is well-typed and s'
       satisfies I.
   (1) is shown by inline1_fixp_induct; (2) and (3) use the locale's
   inline1_in_sub_gpvs_callee and inline1_in_sub_gpvs with WT_sub_gpvsD. *)
lemma WT_gpv_inline1:
  assumes "Inr (out, rpv, rpv')  set_spmf (inline1 callee gpv s)"
    and " ⊢g gpv "
    and "I s"
  shows "out  outs_ℐ ℐ'" (is "?thesis1")
    and "input  responses_ℐ ℐ' out  ℐ' ⊢g rpv input " (is "PROP ?thesis2")
    and " input  responses_ℐ ℐ' out; (x, s')  results_gpv ℐ' (rpv input)    ⊢g rpv' x   I s'" (is "PROP ?thesis3")
proof -
  from  ⊢g gpv  I s
  have "set_spmf (inline1 callee gpv s)  {Inr (out, rpv, rpv') | out rpv rpv'. out  outs_ℐ ℐ'}  {Inl (x, s')| x s'. I s'}"
  proof(induction arbitrary: gpv s rule: inline1_fixp_induct)
    { case adm show ?case by(intro cont_intro ccpo_class.admissible_leI) }
    { case bottom show ?case by simp }
    case (step inline1')
    { fix out c
      assume IO: "IO out c  set_spmf (the_gpv gpv)" 
      from step.prems(1) IO have out: "out  outs_ℐ " by(rule WT_gpvD)
      { fix x s'
        assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
        then have *: "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
        with out step.prems(2) have "x  responses_ℐ  out" "I s'" by(auto dest: results_callee)
        from step.prems(1) IO this(1) have " ⊢g c x " by(rule WT_gpvD)
        note this I s'
      } moreover {
        fix out' c'
        from out step.prems(2) have "ℐ' ⊢g callee s out " by(rule WT_callee)
        moreover assume "IO out' c'  set_spmf (the_gpv (callee s out))"
        ultimately have "out'  outs_ℐ ℐ'" by(rule WT_gpvD) 
      } moreover note calculation }
      then show ?case using step.prems(2)
        by(auto del: subsetI simp add: bind_UNION intro!: UN_least split: generat.split intro!: step.IH[THEN order_trans])
    qed
    then show ?thesis1 using assms by auto

    (* Second statement. *)
    assume "input  responses_ℐ ℐ' out"
    with inline1_in_sub_gpvs_callee[OF ‹Inr _  _  ⊢g gpv  I s]
    obtain out' s where "out'  outs_ℐ " 
      and *: "rpv input  sub_gpvs ℐ' (callee s out')" and "I s" by blast
    from out'  _ I s have "ℐ' ⊢g callee s out' " by(rule WT_callee)
    then show "ℐ' ⊢g rpv input " using * by(rule WT_sub_gpvsD)

    (* Third statement. *)
    assume "(x, s')  results_gpv ℐ' (rpv input)"
    with ‹Inr _  _ have "rpv' x  sub_gpvs  gpv  I s'"
      using input  _  ⊢g gpv  assms(3) I s by-(rule inline1_in_sub_gpvs)
    with  ⊢g gpv  show " ⊢g rpv' x   I s'" by(blast intro: WT_sub_gpvsD)
  qed

(* Inlining preserves well-typedness under the state invariant: for a
   well-typed gpv and an initial state s with I s, inline callee gpv s is
   well-typed w.r.t. the second interface.  Proved by coinduction with
   WT_gpv_coinduct_bind, using the locale's WT_gpv_inline1 for one step of
   inlining. *)
lemma WT_gpv_inline_invar:
  assumes " ⊢g gpv "
    and "I s"
  shows "ℐ' ⊢g inline callee gpv s "
  using assms
proof(coinduction arbitrary: gpv s rule: WT_gpv_coinduct_bind)
  case (WT_gpv out c gpv)
  (* An IO step of the inlined GPV stems from an Inr result of inline1. *)
  from ‹IO out c  _ obtain callee' rpv'
    where Inr: "Inr (out, callee', rpv')  set_spmf (inline1 callee gpv s)"
      and c: "c = (λinput. callee' input  (λ(x, s). inline callee (rpv' x) s))"
    by(clarsimp simp add: inline_sel split: sum.split_asm)
  from Inr  ⊢g gpv  I s have ?out by(rule WT_gpv_inline1)
  moreover have "?cont TYPE('ret × 's)" (is "input_. _  _  ?case' input")
  proof(rule ballI disjI2)+
    fix input
    assume "input  responses_ℐ ℐ' out"
    with Inr  ⊢g gpv  I s have "ℐ' ⊢g callee' input "
      and "x s'. (x, s')  results_gpv ℐ' (callee' input)   ⊢g rpv' x   I s'"
      by(blast dest: WT_gpv_inline1)+
    then show "?case' input" by(subst c)(auto 4 5)
  qed
  ultimately show "?case TYPE('ret × 's)" ..
qed

end

(* Invariant-free corollary of WT_gpv_inline_invar, stated outside the locale:
   the locale raw_converter_invariant is interpreted with the trivial
   invariant (λ_. True), so the assumptions only mention the results and the
   well-typedness of callee. *)
lemma WT_gpv_inline':
  assumes "s x. x  outs_ℐ   results_gpv ℐ' (callee s x)  responses_ℐ  x × UNIV"
    and "x s. x  outs_ℐ   ℐ' ⊢g callee s x "
    and " ⊢g gpv "
  shows "ℐ' ⊢g inline callee gpv s "
proof -
  interpret raw_converter_invariant  ℐ' callee "λ_. True" 
    using assms by(unfold_locales)auto
  show ?thesis by(rule WT_gpv_inline_invar)(use assms in auto)
qed

(* Relates the results of a sub-GPV gpv' (a member of sub_gpvs of gpv) to the
   results of gpv itself.  Proved by induction on the sub_gpvs derivation,
   using results_gpv.IO. *)
lemma results_gpv_sub_gvps: "gpv'  sub_gpvs  gpv  results_gpv  gpv'  results_gpv  gpv"
  by(induction rule: sub_gpvs.induct)(auto intro: results_gpv.IO)

(* Element-wise form of results_gpv_sub_gvps: a result x of a sub-GPV gpv' of
   gpv is also a result of gpv. *)
lemma in_results_gpv_sub_gvps: " x  results_gpv  gpv'; gpv'  sub_gpvs  gpv   x  results_gpv  gpv"
  using results_gpv_sub_gvps[of gpv'  gpv] by blast

context raw_converter_invariant begin
(* Results of inline_aux callee y, for both shapes of the argument y:
   (1) y = Inl (gpv, s) with gpv well-typed and I s: a result (x, s') gives a
       result x of gpv and I s';
   (2) y = Inr (rpv, callee') where every result (z, s') of callee' has a
       well-typed rpv z and I s': x is a result of rpv z for some result
       (z, s'') of callee', and both I s'' and I s' hold.
   Proved by induction on the derivation of the assumption (cases Pure and
   IO), each case being split into the sub-cases 1 and 2 for the two
   statements. *)
lemma results_gpv_inline_aux:
  assumes "(x, s')  results_gpv ℐ' (inline_aux callee y)"
  shows " y = Inl (gpv, s);  ⊢g gpv ; I s   x  results_gpv  gpv  I s'"
    and " y = Inr (rpv, callee'); (z, s')  results_gpv ℐ' callee'.  ⊢g rpv z   I s' 
     (z, s'')  results_gpv ℐ' callee'. x  results_gpv  (rpv z)  I s''  I s'"
  using assms
proof(induction gvp'"inline_aux callee y" arbitrary: y gpv s rpv callee')
  case Pure case 1
  with Pure show ?case
    by(auto simp add: inline_aux.sel split: sum.split_asm dest: inline1_Inl_results_gpv)
next
  case Pure case 2
  with Pure show ?case
    by(clarsimp simp add: inline_aux.sel split: sum.split_asm)
      (fastforce split: generat.split_asm dest: inline1_Inl_results_gpv intro: results_gpv.Pure)+
next
  (* IO step, statement (1): the step stems from an Inr result of inline1. *)
  case (IO out c input) case 1
  with IO(1) obtain rpv rpv' where inline1: "Inr (out, rpv, rpv')  set_spmf (inline1 callee gpv s)"
    and c: "c = (λinput. inline_aux callee (Inr (rpv', rpv input)))"
    by(auto simp add: inline_aux.sel split: sum.split_asm)
  from inline1[THEN inline1_in_sub_gpvs, OF _ input  responses_ℐ ℐ' out _ I s]  ⊢g gpv 
  have "(z, s')results_gpv ℐ' (rpv input).  ⊢g rpv' z   I s'"
    by(auto intro: WT_sub_gpvsD)
  from IO(5)[unfolded c, OF refl refl this] obtain input' s'' 
    where input': "(input', s'')  results_gpv ℐ' (rpv input)" 
      and x: "x  results_gpv  (rpv' input')" and s'': "I s''" "I s'"
    by auto
  from inline1[THEN inline1_in_sub_gpvs, OF input' input  responses_ℐ ℐ' out  ⊢g gpv  I s] s'' x
  show ?case by(auto intro: in_results_gpv_sub_gvps)
next
  (* IO step, statement (2): either callee' has terminated (Pure) and the step
     comes from inlining rpv input', or callee' itself makes the call (Cont). *)
  case (IO out c input) case 2
  from IO(1) "2"(1) consider (Pure) input' s'' rpv' rpv''
    where "Pure (input', s'')  set_spmf (the_gpv callee')" "Inr (out, rpv', rpv'')  set_spmf (inline1 callee (rpv input') s'')"
      "c = (λinput. inline_aux callee (Inr (rpv'', rpv' input)))"
    | (Cont) rpv' where "IO out rpv'  set_spmf (the_gpv callee')" "c = (λinput. inline_aux callee (Inr (rpv, rpv' input)))"
    by(auto simp add: inline_aux.sel split: sum.split_asm; rename_tac generat; case_tac generat; clarsimp)
  then show ?case
  proof cases
    case Pure
    have res: "(input', s'')  results_gpv ℐ' callee'" using Pure(1) by(rule results_gpv.Pure)
    with 2 have WT: " ⊢g rpv input' " "I s''" by auto
    have "(z, s')results_gpv ℐ' (rpv' input).  ⊢g rpv'' z   I s'"
      using inline1_in_sub_gpvs[OF Pure(2) _ input  _ WT] WT by(auto intro: WT_sub_gpvsD)
    from IO(5)[unfolded Pure(3), OF refl refl this] obtain z s'''
      where z: "(z, s''')  results_gpv ℐ' (rpv' input)"
        and x: "x  results_gpv  (rpv'' z)" and s': "I s'''" "I s'" by auto
    have "x  results_gpv  (rpv input')" using x inline1_in_sub_gpvs[OF Pure(2) z input  _ WT]
      by(auto intro: in_results_gpv_sub_gvps)
    then show ?thesis using res WT s' by auto
  next
    case Cont
    have "(z, s')results_gpv ℐ' (rpv' input).  ⊢g rpv z   I s'" 
      using Cont 2 input  responses_ℐ ℐ' out by(auto intro: results_gpv.IO)
    from IO(5)[unfolded Cont, OF refl refl this] obtain z s'' 
      where "(z, s'')  results_gpv ℐ' (rpv' input)" "x  results_gpv  (rpv z)" "I s''" "I s'" by auto
    then show ?thesis using Cont(1) input  _ by(auto intro: results_gpv.IO)
  qed
qed

(* Results of inlining: a result (x, s') of inline callee gpv s, for a
   well-typed gpv and an initial state with I s, gives a result x of gpv and a
   final state s' satisfying I.  Obtained from the first statement of
   results_gpv_inline_aux after unfolding inline_def. *)
lemma results_gpv_inline: 
  "(x, s')  results_gpv ℐ' (inline callee gpv s);  ⊢g gpv ; I s  x  results_gpv  gpv  I s'"
  unfolding inline_def by(rule results_gpv_inline_aux(1)[OF _ refl])

end

(* Inlining into a mapped GPV: mapping results with f and calls with g before
   inlining equals inlining with the callee precomposed with g and mapping f
   over the first component of the results afterwards (apfst f).
   Proved from inline_parametric by instantiating its relations with graphs
   (BNF_Def.Grp) of id, g and f. *)
lemma inline_map_gpv:
  "inline callee (map_gpv f g gpv) s = map_gpv (apfst f) id (inline (λs x. callee s (g x)) gpv s)"
  unfolding apfst_def
  by(rule inline_parametric
      [where S="BNF_Def.Grp UNIV id" and C="BNF_Def.Grp UNIV g" and C'="BNF_Def.Grp UNIV id" and A="BNF_Def.Grp UNIV f",
        THEN rel_funD, THEN rel_funD, THEN rel_funD,
        unfolded gpv.rel_Grp prod.rel_Grp, simplified, folded eq_alt, unfolded Grp_def, simplified])
    (auto simp add: rel_fun_def relator_eq)

subsection ‹Running GPVs›

(* Type of probabilistic oracles: given a state of type 's and a call of type
   'call, an spmf over pairs of a response of type 'ret and a new state. *)
type_synonym ('call, 'ret, 's) callee = "'s  'call  ('ret × 's) spmf"

context fixes callee :: "('call, 'ret, 's) callee" notes [[function_internals]] begin

(* Runs the GPV c against the oracle callee, starting in oracle state s, and
   returns an spmf over pairs of result and final state.  The first step
   the_gpv c is analysed with case_generat: a pure result x is returned
   together with the current state s; for a call out, the oracle callee s out
   is queried and execution continues with exec_gpv (c x) y for each oracle
   answer (x, y).  Defined as a partial function in the spmf mode. *)
partial_function (spmf) exec_gpv :: "('a, 'call, 'ret) gpv  's  ('a × 's) spmf"
where
  "exec_gpv c s =
   the_gpv c 
     case_generat (λx. return_spmf (x, s))
     (λout c. callee s out  (λ(x, y). exec_gpv (c x) y))"

(* run_gpv is exec_gpv with the final oracle state dropped (map_spmf fst). *)
abbreviation run_gpv :: "('a, 'call, 'ret) gpv  's  'a spmf"
where "run_gpv gpv s  map_spmf fst (exec_gpv gpv s)"

(* Fixpoint induction rule for exec_gpv in curried form, with the named cases
   adm (admissibility of P), bottom (P holds for the function that always
   returns return_pmf None) and step (P is preserved by one unfolding of the
   defining equation).  Derived from exec_gpv.fixp_induct. *)
lemma exec_gpv_fixp_induct [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λf. P (λc s. f (c, s)))"
  and "P (λ_ _. return_pmf None)"
  and "exec_gpv. P exec_gpv  
     P (λc s. the_gpv c  case_generat (λx. return_spmf (x, s)) (λout c. callee s out  (λ(x, y). exec_gpv (c x) y)))"
  shows "P exec_gpv"
using assms(1)
by(rule exec_gpv.fixp_induct[unfolded curry_conv[abs_def]])(simp_all add: assms(2-))

(* Stronger fixpoint induction for exec_gpv: in the step case the approximation
   exec_gpv' may additionally be assumed to be below exec_gpv pointwise
   w.r.t. ord_spmf (=). Derived from spmf.fixp_strong_induct_uc. *)
lemma exec_gpv_fixp_induct_strong [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λf. P (λc s. f (c, s)))"
  and "P (λ_ _. return_pmf None)"
  and "exec_gpv'.  c s. ord_spmf (=) (exec_gpv' c s) (exec_gpv c s); P exec_gpv' 
     P (λc s. the_gpv c  case_generat (λx. return_spmf (x, s)) (λout c. callee s out  (λ(x, y). exec_gpv' (c x) y)))"
  shows "P exec_gpv"
using assms
by(rule spmf.fixp_strong_induct_uc[where P="λf. P (curry f)" and U=case_prod and C=curry, OF exec_gpv.mono exec_gpv_def, simplified curry_case_prod, simplified curry_conv[abs_def] fun_ord_def split_paired_All prod.case case_prod_eta, OF refl]) blast

(* Variant of the strong fixpoint induction with two extra step hypotheses:
   the approximation exec_gpv' is below exec_gpv, and it is below its own
   one-step unfolding (both pointwise w.r.t. ord_spmf (=)).
   Derived from spmf.fixp_induct_strong2_uc. *)
lemma exec_gpv_fixp_induct_strong2 [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λf. P (λc s. f (c, s)))"
  and "P (λ_ _. return_pmf None)"
  and "exec_gpv'.
     c s. ord_spmf (=) (exec_gpv' c s) (exec_gpv c s); 
      c s. ord_spmf (=) (exec_gpv' c s) (the_gpv c  case_generat (λx. return_spmf (x, s)) (λout c. callee s out  (λ(x, y). exec_gpv' (c x) y)));
      P exec_gpv' 
     P (λc s. the_gpv c  case_generat (λx. return_spmf (x, s)) (λout c. callee s out  (λ(x, y). exec_gpv' (c x) y)))"
  shows "P exec_gpv"
using assms
by(rule spmf.fixp_induct_strong2_uc[where P="λf. P (curry f)" and U=case_prod and C=curry, OF exec_gpv.mono exec_gpv_def, simplified curry_case_prod, simplified curry_conv[abs_def] fun_ord_def split_paired_All prod.case case_prod_eta, OF refl]) blast+

end

(* exec_gpv expressed through inline1: lift the callee's spmf into a gpv whose
   call and response types are unit, run inline1, and take projl of the result.
   Proved by parallel fixpoint induction over the definitions of exec_gpv and
   inline1. Several later lemmas about exec_gpv are reduced to inline1 facts
   through this equation. *)
lemma exec_gpv_conv_inline1:
  "exec_gpv callee gpv s = map_spmf projl (inline1 (λs c. lift_spmf (callee s c) :: (_, unit, unit) gpv) gpv s)"
by(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf exec_gpv.mono inline1.mono exec_gpv_def inline1_def, unfolded lub_spmf_empty, case_names adm bottom step])
  (auto simp add: map_spmf_bind_spmf o_def spmf.map_comp bind_map_spmf split_def intro!: bind_spmf_cong[OF refl] split: generat.split)

(* The unfolding equation of exec_gpv, restated outside the defining context
   with the callee as an explicit argument (it is exactly exec_gpv.simps). *)
lemma exec_gpv_simps:
  "exec_gpv callee gpv s =
   the_gpv gpv 
     case_generat (λx. return_spmf (x, s))
     (λout rpv. callee s out  (λ(x, y). exec_gpv callee (rpv x) y))"
by(fact exec_gpv.simps)

(* Executing a lifted spmf p samples x from p and returns it paired with the
   unchanged state s. *)
lemma exec_gpv_lift_spmf [simp]:
  "exec_gpv callee (lift_spmf p) s = bind_spmf p (λx. return_spmf (x, s))"
by(simp add: exec_gpv_conv_inline1 spmf.map_comp o_def map_spmf_conv_bind_spmf)

(* Executing Done x returns x together with the unchanged state. *)
lemma exec_gpv_Done [simp]: "exec_gpv callee (Done x) s = return_spmf (x, s)"
by(simp add: exec_gpv_conv_inline1)

(* Executing Fail yields the failing spmf (return_pmf None). *)
lemma exec_gpv_Fail [simp]: "exec_gpv callee Fail s = return_pmf None"
by(simp add: exec_gpv_conv_inline1)

(* exec_gpv distributes over if-then-else in its gpv argument; collected in the
   named theorem collection if_distribs. *)
lemma if_distrib_exec_gpv [if_distribs]:
  "exec_gpv callee (if b then x else y) s = (if b then exec_gpv callee x s else exec_gpv callee y s)"
by simp

(* Parallel fixpoint induction over two instances of exec_gpv (for example with
   different callees), with cases adm, bottom and step. *)
lemmas exec_gpv_fixp_parallel_induct [case_names adm bottom step] =
  parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf exec_gpv.mono exec_gpv.mono exec_gpv_def exec_gpv_def, unfolded lub_spmf_empty]

context includes lifting_syntax begin

(* Parametricity of exec_gpv for the generalised relator rel_gpv'':
   S relates states, CALL relates calls, R relates responses, A relates results.
   The proof rewrites exec_gpv to inline1 (exec_gpv_conv_inline1) and then
   appeals to inline1_parametric'; the remaining side conditions are discharged
   by simp and by transfer_prover using lift_spmf_parametric'. *)
lemma exec_gpv_parametric':
  "((S ===> CALL ===> rel_spmf (rel_prod R S)) ===> rel_gpv'' A CALL R ===> S ===> rel_spmf (rel_prod A S))
  exec_gpv exec_gpv"
apply(rule rel_funI)+
apply(unfold spmf_rel_map exec_gpv_conv_inline1)
apply(rule rel_spmf_mono_strong)
 apply(erule inline1_parametric'[THEN rel_funD, THEN rel_funD, THEN rel_funD, rotated])
  prefer 3
  apply(drule in_set_inline1_lift_spmf1)+
  apply fastforce
 subgoal by simp
subgoal premises [transfer_rule]
  supply lift_spmf_parametric'[transfer_rule] by transfer_prover
done

(* Transfer rule for exec_gpv with the ordinary relator rel_gpv, i.e. with
   responses related by equality. It is the instance of exec_gpv_parametric'
   obtained by unfolding rel_gpv_conv_rel_gpv''. *)
lemma exec_gpv_parametric [transfer_rule]:
  "((S ===> CALL ===> rel_spmf (rel_prod ((=) :: 'ret  _) S)) ===> rel_gpv A CALL ===> S ===> rel_spmf (rel_prod A S))
  exec_gpv exec_gpv"
unfolding rel_gpv_conv_rel_gpv'' by(rule exec_gpv_parametric')

end

(* exec_gpv and sequential composition: executing a bound gpv is the same as
   executing c from s and then executing f x from the state s' that the first
   execution ended in. *)
lemma exec_gpv_bind: "exec_gpv callee (c  f) s = exec_gpv callee c s  (λ(x, s')  exec_gpv callee (f x) s')"
by(auto simp add: exec_gpv_conv_inline1 inline1_bind_gpv map_spmf_bind_spmf o_def bind_map_spmf intro!: bind_spmf_cong[OF refl] dest: in_set_inline1_lift_spmf1)

(* Mapping f over the results of a gpv (calls untouched, id) commutes with
   exec_gpv: the outcome is the original outcome with f applied to the first
   component (apfst f).
   Proof idea: relate gpv and its image by the graph of f, obtain the relational
   statement for exec_gpv by transfer_prover, and turn it into an equation. *)
lemma exec_gpv_map_gpv_id:
  "exec_gpv oracle (map_gpv f id gpv) σ = map_spmf (apfst f) (exec_gpv oracle gpv σ)"
proof(rule sym)
  define gpv' where "gpv' = map_gpv f id gpv"
  have [transfer_rule]: "rel_gpv (λx y. y = f x) (=) gpv gpv'"
    unfolding gpv'_def by(simp add: gpv.rel_map gpv.rel_refl)
  have "rel_spmf (rel_prod (λx y. y = f x) (=)) (exec_gpv oracle gpv σ) (exec_gpv oracle gpv' σ)"
    by transfer_prover
  thus "map_spmf (apfst f) (exec_gpv oracle gpv σ) = exec_gpv oracle (map_gpv f id gpv) σ"
    unfolding spmf_rel_eq[symmetric] gpv'_def spmf_rel_map by(rule rel_spmf_mono) clarsimp
qed

(* Executing Pause out f asks the callee about out in state s and continues
   with f applied to the answer x in the callee's new state s'. *)
lemma exec_gpv_Pause [simp]:
  "exec_gpv callee (Pause out f) s = callee s out  (λ(x, s'). exec_gpv callee (f x) s')"
by(simp add: inline1_Pause map_spmf_bind_spmf bind_map_spmf o_def exec_gpv_conv_inline1 split_def)

(* Special case of exec_gpv_bind where the first computation is a lifted spmf:
   sampling from p does not touch the state, so f x is executed from s. *)
lemma exec_gpv_bind_lift_spmf:
  "exec_gpv callee (bind_gpv (lift_spmf p) f) s = bind_spmf p (λx. exec_gpv callee (f x) s)"
by(simp add: exec_gpv_bind)

(* exec_gpv commutes with monad.bind_option: the failure element Fail on the
   gpv side corresponds to return_pmf None on the spmf side.
   Proved by case distinction on the option value x. *)
lemma exec_gpv_bind_option [simp]:
  "exec_gpv oracle (monad.bind_option Fail x f) s = monad.bind_option (return_pmf None) x (λa. exec_gpv oracle (f a) s)"
by(cases x) simp_all

(* Predicate lifting for exec_gpv. Premises: gpv satisfies pred_gpv A C, the
   callee maps states satisfying S and calls satisfying C to outcomes whose
   state component satisfies S, and the start state satisfies S. Conclusion:
   every outcome of exec_gpv satisfies A on the result and S on the state.
   The proof instantiates exec_gpv_parametric with eq_onp relations and
   rewrites the relators on eq_onp to predicators. *)
lemma pred_spmf_exec_gpv:
  ― ‹We don't get an equivalence here because states are threaded through in @{const exec_gpv}.›
  " pred_gpv A C gpv; pred_fun S (pred_fun C (pred_spmf (pred_prod (λ_. True) S))) callee; S s 
   pred_spmf (pred_prod A S) (exec_gpv callee gpv s)"
using exec_gpv_parametric[of "eq_onp S" "eq_onp C" "eq_onp A", folded eq_onp_True]
apply(unfold prod.rel_eq_onp option.rel_eq_onp pmf.rel_eq_onp gpv.rel_eq_onp)
apply(drule rel_funD[where x=callee and y=callee])
 subgoal
   apply(rule rel_fun_mono[where X="eq_onp S"])
     apply(rule rel_fun_eq_onpI)
     apply(unfold eq_onp_same_args)
     apply assumption
    apply simp
   apply(erule rel_fun_eq_onpI)
   done
apply(auto dest!: rel_funD simp add: eq_onp_def)
done

(* Executing an inlined computation in one go versus in two layers.
   Left: inline gpv into c' (inner state s'), then execute with callee (outer
   state s). Right: execute c' with a combined callee on paired states (s', s)
   that answers each call y by executing "gpv s' y" with callee; the outer
   map_spmf only re-associates the resulting tuples.
   The proof rewrites both sides to inline1/inline2 form via
   exec_gpv_conv_inline1 and the inline1_inline_conv_inline2 lemmas. *)
lemma exec_gpv_inline:
  fixes callee :: "('c, 'r, 's) callee"
  and gpv :: "'s'  'c'  ('r' × 's', 'c, 'r) gpv"
  shows "exec_gpv callee (inline gpv c' s') s =
    map_spmf (λ(x, s', s). ((x, s'), s)) (exec_gpv (λ(s', s) y. map_spmf (λ((x, s'), s). (x, s', s)) (exec_gpv callee (gpv s' y) s)) c' (s', s))"
    (is "?lhs = ?rhs")
proof -
  have "?lhs = map_spmf projl (map_spmf (map_sum (λ(x, s2, y). ((x, s2), y))
        (λ(x, rpv'' :: ('r × 's, unit, unit) rpv, rpv', rpv). (x, rpv'', λr1. bind_gpv (rpv' r1) (λ(r2, y). inline gpv (rpv r2) y))))
      (inline2 (λs c. lift_spmf (callee s c)) gpv c' s' s))"
    unfolding exec_gpv_conv_inline1 by(simp add: inline1_inline_conv_inline2)
  also have " = map_spmf (λ(x, s', s). ((x, s'), s)) (map_spmf projl (map_spmf (map_sum id
        (λ(x, rpv'' :: ('r × 's, unit, unit) rpv, rpv', rpv). (x, λr. bind_gpv (rpv'' r) (λ(r1, s1). map_gpv (λ((r2, s2), s1). (r2, s2, s1)) id (inline (λs c. lift_spmf (callee s c)) (rpv' r1) s1)), rpv)))
      (inline2 (λs c. lift_spmf (callee s c)) gpv c' s' s)))"
   unfolding spmf.map_comp by(rule map_spmf_cong[OF refl])(auto dest!: in_set_inline2_lift_spmf1)
  also have " = ?rhs" unfolding exec_gpv_conv_inline1
    by(subst inline1_inline_conv_inline2'[symmetric])(simp add: spmf.map_comp split_def inline_lift_spmf1 map_lift_spmf)
  finally show ?thesis .
qed

(* Monotonicity of exec_gpv in the callee: if callee1 is pointwise below
   callee2 w.r.t. ord_spmf (=), then so are the executions of the same gpv from
   the same state. Proved by parallel fixpoint induction over both executions. *)
lemma ord_spmf_exec_gpv:
  assumes callee: "s x. ord_spmf (=) (callee1 s x) (callee2 s x)"
  shows "ord_spmf (=) (exec_gpv callee1 gpv s) (exec_gpv callee2 gpv s)"
proof(induction arbitrary: gpv s rule: exec_gpv_fixp_parallel_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
next
  case (step exec_gpv1 exec_gpv2)
  show ?case using step.prems
    by(clarsimp intro!: ord_spmf_bind_reflI ord_spmf_bindI[OF assms] step.IH split!: generat.split)
qed

context fixes callee :: "('call, 'ret, 's) callee" notes [[function_internals]] begin

(* Execution of a resumption against the fixed stateful callee.
   - resumption.Done x: the optional result x is paired with the current state
     (map_option), so a None result gives the failing spmf;
   - resumption.Pause out c: the callee is asked about out in state s and
     execution continues with c applied to the answer in the new state s'. *)
partial_function (spmf) execp_resumption :: "('a, 'call, 'ret) resumption  's  ('a × 's) spmf"
where
  "execp_resumption r s = (case r of resumption.Done x  return_pmf (map_option (λa. (a, s)) x)
      | resumption.Pause out c  bind_spmf (callee s out) (λ(input, s'). execp_resumption (c input) s'))"

(* Split the case-expression equation execp_resumption.simps into one simp rule
   per constructor. *)
simps_of_case execp_resumption_simps [simp]: execp_resumption.simps

(* Executing ABORT yields the failing spmf. *)
lemma execp_resumption_ABORT [simp]: "execp_resumption ABORT s = return_pmf None"
by(simp add: ABORT_def)

(* Executing DONE x returns x together with the unchanged state. *)
lemma execp_resumption_DONE [simp]: "execp_resumption (DONE x) s = return_spmf (x, s)"
by(simp add: DONE_def)

(* Executing a resumption lifted to a gpv agrees with executing the resumption
   directly. Proved by parallel fixpoint induction over the definitions of
   exec_gpv and execp_resumption. *)
lemma exec_gpv_lift_resumption: "exec_gpv callee (lift_resumption r) s = execp_resumption r s"
proof(induction arbitrary: r s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf exec_gpv.mono execp_resumption.mono exec_gpv_def execp_resumption_def, case_names adm bot step])
  case adm show ?case by(simp)
  case bot thus ?case by simp
  case (step exec_gpv' execp_resumption')
  show ?case
    by(auto split: resumption.split option.split simp add: lift_resumption.sel intro: bind_spmf_cong step)
qed

(* execp_resumption is monotone and continuous (mcont) in its resumption
   argument for every fixed state s, w.r.t. resumption_ord / ord_spmf (=).
   The fact is stored both as mcont_execp_resumption and, composed with
   spmf.mcont2mcont, as a cont_intro/simp rule.
   Proof: first show mcont of the uncurried function on (resumption, state)
   pairs by fixpoint preservation (ccpo.fixp_preserves_mcont2), then project to
   a fixed state. *)
lemma mcont2mcont_execp_resumption [THEN spmf.mcont2mcont, cont_intro, simp]:
  shows mcont_execp_resumption:
  "mcont resumption_lub resumption_ord lub_spmf (ord_spmf (=)) (λr. execp_resumption r s)"
proof -
  have "mcont (prod_lub resumption_lub the_Sup) (rel_prod resumption_ord (=)) lub_spmf (ord_spmf (=)) (case_prod execp_resumption)"
  proof(rule ccpo.fixp_preserves_mcont2[OF ccpo_spmf execp_resumption.mono execp_resumption_def])
    fix execp_resumption' :: "('b, 'call, 'ret) resumption  's  ('b × 's) spmf"
    assume *: "mcont (prod_lub resumption_lub the_Sup) (rel_prod resumption_ord (=)) lub_spmf (ord_spmf (=)) (λ(r, s). execp_resumption' r s)"
    have [THEN spmf.mcont2mcont, cont_intro, simp]: "mcont resumption_lub resumption_ord lub_spmf (ord_spmf (=)) (λr. execp_resumption' r s)" 
      for s using * by simp
    have "mcont resumption_lub resumption_ord lub_spmf (ord_spmf (=))
      (λr. case r of resumption.Done x  return_pmf (map_option (λa. (a, s)) x)
           | resumption.Pause out c  bind_spmf (callee s out) (λ(input, s'). execp_resumption' (c input) s'))"
      for s by(rule mcont_case_resumption)(auto simp add: ccpo_spmf intro!: mcont_bind_spmf)
    thus "mcont (prod_lub resumption_lub the_Sup) (rel_prod resumption_ord (=)) lub_spmf (ord_spmf (=))
          (λ(r, s). case r of resumption.Done x  return_pmf (map_option (λa. (a, s)) x)
              | resumption.Pause out c  bind_spmf (callee s out) (λ(input, s'). execp_resumption' (c input) s'))"
      by simp
  qed
  thus ?thesis by auto
qed


(* execp_resumption and sequential composition of resumptions: execute r first
   and then f x from the resulting state. Reduced to exec_gpv_bind through
   exec_gpv_lift_resumption. *)
lemma execp_resumption_bind [simp]:
  "execp_resumption (r  f) s = execp_resumption r s  (λ(x, s'). execp_resumption (f x) s')"
by(simp add: exec_gpv_lift_resumption[symmetric] lift_resumption_bind exec_gpv_bind)

(* Predicate lifting for execp_resumption, the analogue of pred_spmf_exec_gpv:
   it is obtained from that lemma after rewriting execp_resumption as exec_gpv
   of the lifted resumption. *)
lemma pred_spmf_execp_resumption:
  "A.  pred_resumption A C r; pred_fun S (pred_fun C (pred_spmf (pred_prod (λ_. True) S))) callee; S s 
   pred_spmf (pred_prod A S) (execp_resumption r s)"
unfolding exec_gpv_lift_resumption[symmetric]
by(rule pred_spmf_exec_gpv) simp_all

end

(* Well-typedness of a (state-less view of a) callee w.r.t. an interface,
   written with the "⊢c" mixfix notation. The single introduction rule requires
   that for every call in the interface's outs_ℐ set, every pair (ret, s) in the
   support of "callee call" has ret among the responses_ℐ for that call. *)
inductive WT_callee :: "('call, 'ret) ('call  ('ret × 's) spmf)  bool" ("(_) ⊢c/ (_) " [100, 0] 99)
  for  callee
where
  WT_callee:
  " call ret s.  call  outs_ℐ ; (ret, s)  set_spmf (callee call)   ret  responses_ℐ  call 
    ⊢c callee "

(* Make the introduction rule available as WT_calleeI and hide the fact name
   WT_callee. NOTE(review): presumably done because WT_callee is reused as an
   assumption name in locale callee_invariant_on. *)
lemmas WT_calleeI = WT_callee
hide_fact WT_callee

(* Destruction rule for WT_callee: for a well-typed callee, a returned value
   ret for a call out from the interface's outs_ℐ set lies in the responses_ℐ
   for out. *)
lemma WT_calleeD: "  ⊢c callee ; (ret, s)  set_spmf (callee out); out  outs_ℐ    ret  responses_ℐ  out"
by(rule WT_callee.cases)

(* Every callee is well-typed w.r.t. the full interface ℐ_full. *)
lemma WT_callee_full [intro!, simp]: "ℐ_full ⊢c callee "
by(rule WT_calleeI) simp

(* Transfer rule for WT_callee, assuming the response relation R is bi-unique.
   The proof first rewrites WT_callee into a closed formula over the support of
   the callee (fact * ) and then lets transfer_prover do the rest. *)
lemma WT_callee_parametric [transfer_rule]:
  includes lifting_syntax 
  assumes [transfer_rule]: "bi_unique R"
  shows "(rel_ℐ C R ===> (C ===> rel_spmf (rel_prod R S)) ===> (=)) WT_callee WT_callee"
proof -
  have *: "WT_callee = (λ callee. call outs_ℐ . (ret, s)  set_spmf (callee call). ret  responses_ℐ  call)"
    unfolding WT_callee.simps by blast
  show ?thesis unfolding * by transfer_prover
qed

(* Base locale that only fixes the parameters: a stateful callee, a predicate I
   on states (the invariant), and an interface. It has no assumptions. *)
locale callee_invariant_on_base =
  fixes callee :: "'s  'a  ('b × 's) spmf"
  and I :: "'s  bool"
  and  :: "('a, 'b) ℐ"

(* Locale for a callee with a state invariant I relative to an interface.
   Assumptions:
   callee_invariant - if I holds for s and x is in the interface's outs_ℐ set,
                      then I holds for every successor state s' in the support
                      of "callee s x";
   WT_callee        - in every state satisfying I, "callee s" is well-typed
                      w.r.t. the interface. *)
locale callee_invariant_on = callee_invariant_on_base callee I 
  for callee :: "'s  'a  ('b × 's) spmf"
  and I :: "'s  bool"
  and  :: "('a, 'b) ℐ"
  +
  assumes callee_invariant: "s x y s'.  (y, s')  set_spmf (callee s x); I s; x  outs_ℐ    I s'"
  and WT_callee: "s. I s   ⊢c callee s "
begin

(* Both locale assumptions combined: one callee step from an invariant state on
   an admissible call preserves I and returns an admissible response. *)
lemma callee_invariant': " (y, s')  set_spmf (callee s x); I s; x  outs_ℐ    I s'  y  responses_ℐ  x"
by(auto dest: WT_calleeD[OF WT_callee] callee_invariant)

(* Set formulation of invariant preservation: executing a well-typed gpv from a
   state satisfying I only produces outcomes whose state satisfies I.
   Proved by fixpoint induction on exec_gpv. *)
lemma exec_gpv_invariant':
  " I s;  ⊢g gpv    set_spmf (exec_gpv callee gpv s)  {(x, s'). I s'}"
proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct)
  case adm show ?case by(intro cont_intro ccpo_class.admissible_leI)
  case bottom show ?case by simp
  case step show ?case using step.prems
    by(auto simp add: bind_UNION intro!: UN_least step.IH del: subsetI split: generat.split dest!: callee_invariant' elim: WT_gpvD)
qed

(* Element-wise version of exec_gpv_invariant': the final state of any outcome
   of executing a well-typed gpv from an invariant state satisfies I. *)
lemma exec_gpv_invariant:
  " (x, s')  set_spmf (exec_gpv callee gpv s); I s;  ⊢g gpv    I s'"
by(drule exec_gpv_invariant') blast+

(* Counting calls during execution (set formulation).
   Assumptions:
   bound  - gpv makes at most n calls that satisfy "consider";
   count  - a callee step on a considered call changes the counter "count" by at
            most one step (bound stated with eSuc);
   ignore - a callee step on a non-considered call does not push the counter
            beyond its previous value;
   WT, I  - gpv is well-typed and the start state satisfies the invariant.
   Conclusion: for every outcome, the counter of the final state is bounded in
   terms of n + count s.
   Proof: fixpoint induction on exec_gpv; the auxiliary statement handles one
   IO step with a case distinction on whether the call is considered. Both
   cases establish the bound for the continuation, the invariant for the new
   state, and well-typedness of the continuation before applying the IH. *)
lemma interaction_bounded_by_exec_gpv_count':
  fixes count
  assumes bound: "interaction_bounded_by consider gpv n"
  and count: "s x y s'.  (y, s')  set_spmf (callee s x); I s; consider x; x  outs_ℐ    count s'  eSuc (count s)"
  and ignore: "s x y s'.  (y, s')  set_spmf (callee s x); I s; ¬ consider x; x  outs_ℐ    count s'  count s"
  and WT: " ⊢g gpv "
  and I: "I s"
  shows "set_spmf (exec_gpv callee gpv s)  {(x, s'). count s'  n + count s}"
using bound I WT
proof(induction arbitrary: gpv s n rule: exec_gpv_fixp_induct)
  case adm show ?case by(intro cont_intro ccpo_class.admissible_leI)
  case bottom show ?case by simp
  case (step exec_gpv')
  have "set_spmf (exec_gpv' (c input) s')  {(x, s''). count s''  n + count s}"
    if out: "IO out c  set_spmf (the_gpv gpv)"
    and input: "(input, s')  set_spmf (callee s out)"
    and X: "out  outs_ℐ "
    for out c input s'
  proof(cases "consider out")
    case True
    (* a considered call uses up one unit of the bound *)
    with step.prems out have "n > 0"
      and bound': "interaction_bounded_by consider (c input) (n - 1)"
      by(auto dest: interaction_bounded_by_contD)
    note bound'
    moreover from input I s X have "I s'" by(rule callee_invariant)
    moreover have " ⊢g c input " using step.prems(3) out WT_calleeD[OF WT_callee input]
      by(rule WT_gpvD)(rule step.prems X)+
    ultimately have "set_spmf (exec_gpv' (c input) s')  {(x, s''). count s''  n - 1 + count s'}"      
      by(rule step.IH)
    also have "  {(x, s''). count s''  n + count s}" using n > 0 count[OF input I s True X]     
      by(cases n rule: co.enat.exhaust)(auto, metis add_left_mono_trans eSuc_plus iadd_Suc_right)
    finally show ?thesis .
  next
    case False
    (* an ignored call leaves the bound n for the continuation unchanged *)
    from step.prems out this have bound': "interaction_bounded_by consider (c input) n"
      by(auto dest: interaction_bounded_by_contD_ignore)
    note bound'
    moreover from input I s X have "I s'" by(rule callee_invariant)
    moreover have " ⊢g c input " using step.prems(3) out WT_calleeD[OF WT_callee input]
      by(rule WT_gpvD)(rule step.prems X)+
    ultimately have "set_spmf (exec_gpv' (c input) s')  {(x, s''). count s''  n + count s'}"
      by(rule step.IH)
    also have "  {(x, s''). count s''  n + count s}"
      using ignore[OF input I s False X] by(auto elim: order_trans)
    finally show ?thesis .
  qed
  then show ?case using step.prems(3)
    by(auto 4 3 simp add: bind_UNION del: subsetI intro!: UN_least split: generat.split dest: WT_gpvD)
qed

(* Element-wise version of interaction_bounded_by_exec_gpv_count': for a
   concrete outcome (x, s') of the execution, the counter of s' is bounded in
   terms of n + count s. Obtained from the set formulation via subsetD. *)
lemma interaction_bounded_by_exec_gpv_count:
  fixes count
  assumes bound: "interaction_bounded_by consider gpv n"
  and xs': "(x, s')  set_spmf (exec_gpv callee gpv s)"
  and count: "s x y s'.  (y, s')  set_spmf (callee s x); I s; consider x; x  outs_ℐ    count s'  eSuc (count s)"
  and ignore: "s x y s'.  (y, s')  set_spmf (callee s x); I s; ¬ consider x; x  outs_ℐ    count s'  count s"
  and WT: " ⊢g gpv "
  and I: "I s"
  shows "count s'  n + count s"
using bound count ignore WT I 
by(rule interaction_bounded_by_exec_gpv_count'[THEN subsetD, OF _ _ _ _ _ xs', unfolded mem_Collect_eq prod.case])

(* Variant of interaction_bounded_by_exec_gpv_count for interaction_bounded_by'
   where the counter step is stated with Suc instead of eSuc; derived from the
   eSuc version by simplification with eSuc_enat. *)
lemma interaction_bounded_by'_exec_gpv_count:
  fixes count
  assumes bound: "interaction_bounded_by' consider gpv n"
  and xs': "(x, s')  set_spmf (exec_gpv callee gpv s)"
  and count: "s x y s'.  (y, s')  set_spmf (callee s x); I s; consider x; x  outs_ℐ    count s'  Suc (count s)"
  and ignore: "s x y s'.  (y, s')  set_spmf (callee s x); I s; ¬ consider x; x  outs_ℐ    count s'  count s"
  and outs: " ⊢g gpv "
  and I: "I s"
  shows "count s'  n + count s"
using interaction_bounded_by_exec_gpv_count[OF bound xs', of count] count ignore outs I
by(simp add: eSuc_enat)

(* Predicator form of the invariant: from an invariant state and an admissible
   call, all outcomes of the callee have a state component satisfying I (no
   constraint on the returned value). *)
lemma pred_spmf_calleeI: " I s; x  outs_ℐ    pred_spmf (pred_prod (λ_. True) I) (callee s x)"
by(auto simp add: pred_spmf_def dest: callee_invariant)

(* Losslessness of execution under an invariant: if gpv is lossless and
   well-typed, the callee is lossless on admissible calls in invariant states,
   and the start state satisfies I, then exec_gpv is lossless.
   Proved by induction with lossless_WT_gpv_induct and one unfolding of
   exec_gpv per step. *)
lemma lossless_exec_gpv:
  assumes gpv: "lossless_gpv  gpv"
  and callee: "s out.  out  outs_ℐ ; I s   lossless_spmf (callee s out)"
  and WT_gpv: " ⊢g gpv "
  and I: "I s"
  shows "lossless_spmf (exec_gpv callee gpv s)"
using gpv WT_gpv I
proof(induction arbitrary: s rule: lossless_WT_gpv_induct)
  case (lossless_gpv gpv)
  show ?case using lossless_gpv.hyps lossless_gpv.prems
    by(subst exec_gpv.simps)(fastforce split: generat.split simp add: callee intro!: lossless_gpv.IH intro: WT_calleeD[OF WT_callee] elim!: callee_invariant)
qed

(* Every result produced by executing a well-typed gpv from an invariant state
   is one of the gpv's possible results w.r.t. the interface (results_gpv).
   Proof: show by fixpoint induction that the support of exec_gpv is contained
   in "results_gpv ... gpv × UNIV"; the block in braces treats one IO step:
   the call is admissible, hence the response is admissible, hence the
   continuation is well-typed and the new state satisfies I, so the IH applies. *)
lemma in_set_spmf_exec_gpv_into_results_gpv:
  assumes *: "(x, s')  set_spmf (exec_gpv callee gpv s)"
  and WT_gpv : " ⊢g gpv "
  and I: "I s"
  shows "x  results_gpv  gpv"
proof -
  have "set_spmf (exec_gpv callee gpv s)  results_gpv  gpv × UNIV"
    using WT_gpv I
  proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct)
    { case adm show ?case by(intro cont_intro ccpo_class.admissible_leI) }
    { case bottom show ?case by simp }
    case (step exec_gpv')
    { fix out c ret s'
      assume IO: "IO out c  set_spmf (the_gpv gpv)"
        and ret: "(ret, s')  set_spmf (callee s out)"
      from step.prems(1) IO have "out  outs_ℐ " by(rule WT_gpvD)
      with WT_callee[OF I s] ret have "ret  responses_ℐ  out" by(rule WT_calleeD)
      with step.prems(1) IO have " ⊢g c ret " by(rule WT_gpvD)
      moreover from ret I s out  outs_ℐ  have "I s'" by(rule callee_invariant)
      ultimately have "set_spmf (exec_gpv' (c ret) s')  results_gpv  (c ret) × UNIV"
        by(rule step.IH)
      also have "  results_gpv  gpv × UNIV" using IO ret  _
        by(auto intro: results_gpv.IO)
      finally have "set_spmf (exec_gpv' (c ret) s')  results_gpv  gpv × UNIV" . }
    then show ?case using step.prems
      by(auto simp add: bind_UNION intro!: UN_least del: subsetI split: generat.split intro: results_gpv.Pure)
  qed
  thus "x  results_gpv  gpv" using * by blast+
qed

end

(* Closed-form characterisation of the locale predicate callee_invariant_on
   with bounded quantifiers over invariant states (Collect I), admissible calls
   and the support of the callee; this form is suitable for transfer_prover. *)
lemma callee_invariant_on_alt_def:
  "callee_invariant_on = (λcallee I .
    (s  Collect I. x  outs_ℐ . (y, s')  set_spmf (callee s x). I s') 
    (s  Collect I.  ⊢c callee s ))"
unfolding callee_invariant_on_def by blast

(* Transfer rule for callee_invariant_on, assuming R is bi-unique and S is
   bi-total; proved by transfer_prover on the closed form from
   callee_invariant_on_alt_def. *)
lemma callee_invariant_on_parametric [transfer_rule]: includes lifting_syntax
  assumes [transfer_rule]: "bi_unique R" "bi_total S"
  shows "((S ===> C ===> rel_spmf (rel_prod R S)) ===> (S ===> (=)) ===> rel_ℐ C R ===> (=))
    callee_invariant_on callee_invariant_on"
unfolding callee_invariant_on_alt_def by transfer_prover

(* Congruence rule for callee_invariant_on: with equal invariants and equal
   outs_ℐ sets, the two locale predicates agree provided that, for invariant
   states and admissible calls, the two callees relate as stated in the third
   premise (a condition on their supports w.r.t. "responses × Collect I'"). *)
lemma callee_invariant_on_cong:
  " I = I'; outs_ℐ  = outs_ℐ ℐ'; 
    s x.  I' s; x  outs_ℐ ℐ'   set_spmf (callee s x)  responses_ℐ  x × Collect I'  set_spmf (callee' s x)  responses_ℐ ℐ' x × Collect I' 
   callee_invariant_on callee I  = callee_invariant_on callee' I' ℐ'"
unfolding callee_invariant_on_def WT_callee.simps
by safe((erule meta_allE)+, (erule (1) meta_impE)+, force)+ 

(* callee_invariant is callee_invariant_on specialised to the full interface
   ℐ_full. *)
abbreviation callee_invariant :: "('s  'a  ('b × 's) spmf)  ('s  bool)  bool"
where "callee_invariant callee I  callee_invariant_on callee I ℐ_full"

(* Interpretation of the locale with the trivial invariant (λ_. True) and the
   full interface, for an arbitrary callee; makes the locale lemmas available
   under the prefix oi_True. *)
interpretation oi_True: callee_invariant_on callee "λ_. True" ℐ_full for callee
by unfold_locales (simp_all)

(* Characterisation of callee_invariant_on for a deterministic callee
   (λs x. return_spmf (f s x)): for admissible calls from invariant states,
   the new state snd (f s x) satisfies I and the answer fst (f s x) is an
   admissible response. *)
lemma callee_invariant_on_return_spmf [simp]:
  "callee_invariant_on (λs x. return_spmf (f s x)) I   (s. xouts_ℐ . I s  I (snd (f s x))  fst (f s x)  responses_ℐ  x)"
by(auto simp add: callee_invariant_on_def split_pairs WT_callee.simps)

(* Same characterisation for the full interface: only preservation of I by the
   state component snd (f s x) remains. *)
lemma callee_invariant_return_spmf [simp]:
  "callee_invariant (λs x. return_spmf (f s x)) I  (s x. I s  I (snd (f s x)))"
by(auto simp add: callee_invariant_on_def split_pairs)

(* If two callees are related (states by S) and each has its own invariant
   (I1 resp. I2, full interface), then they stay related when the state
   relation is restricted to pairs of states satisfying I1 and I2.
   The proof interprets the locale for both callees and uses their
   pred_spmf_calleeI facts. *)
lemma callee_invariant_restrict_relp:
  includes lifting_syntax
  assumes "(S ===> C ===> rel_spmf (rel_prod R S)) callee1 callee2"
  and "callee_invariant callee1 I1"
  and "callee_invariant callee2 I2"
  shows "((S  I1  I2) ===> C ===> rel_spmf (rel_prod R (S  I1  I2))) callee1 callee2"
proof -
  interpret ci1: callee_invariant_on callee1 I1 ℐ_full by fact
  interpret ci2: callee_invariant_on callee2 I2 ℐ_full by fact
  show ?thesis using assms(1)
    by(intro rel_funI)(auto simp add: restrict_rel_prod2 intro!: rel_spmf_restrict_relpI intro: ci1.pred_spmf_calleeI ci2.pred_spmf_calleeI dest: rel_funD rel_setD1 rel_setD2)
qed

(* With the trivial invariant, callee_invariant_on reduces to well-typedness of
   "callee s" for all states s. *)
lemma callee_invariant_on_True [simp]: "callee_invariant_on callee (λ_. True)   (s.  ⊢c callee s )"
by(simp add: callee_invariant_on_def)

(* Invariant-free version of callee_invariant_on.lossless_exec_gpv: the callee
   must be lossless on admissible calls and well-typed in every state. *)
lemma lossless_exec_gpv:
  " lossless_gpv  gpv;  s out. out  outs_ℐ   lossless_spmf (callee s out);
      ⊢g gpv ; s.  ⊢c callee s  
   lossless_spmf (exec_gpv callee gpv s)"
by(rule callee_invariant_on.lossless_exec_gpv; simp)

(* Interface-free version: every result of an execution is in results'_gpv.
   Obtained from the locale lemma through the oi_True interpretation and
   results_gpv_ℐ_full. *)
lemma in_set_spmf_exec_gpv_into_results'_gpv:
  assumes *: "(x, s')  set_spmf (exec_gpv callee gpv s)"
  shows "x  results'_gpv gpv"
using oi_True.in_set_spmf_exec_gpv_into_results_gpv[OF *] by(simp add: results_gpv_ℐ_full)


context fixes  :: "('out, 'in) ℐ" begin

(* restrict_gpv cuts a gpv down to the fixed interface, step by step:
   - a failing step (None) stays failing and a Pure result is kept;
   - an IO step whose output is not in outs_ℐ becomes failure (None);
   - an IO step with admissible output is kept, but its continuation is
     replaced by Fail for inputs outside responses_ℐ of that output and is
     restricted corecursively for the other inputs. *)
primcorec restrict_gpv :: "('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv"
where
  "restrict_gpv gpv = GPV (
  map_pmf (case_option None (case_generat (Some  Pure) 
       (λout c. if out  outs_ℐ  then Some (IO out (λinput. if input  responses_ℐ  out then restrict_gpv (c input) else Fail))
          else None)))
    (the_gpv gpv))" 

(* Restriction leaves Done x unchanged. *)
lemma restrict_gpv_Done [simp]: "restrict_gpv (Done x) = Done x"
by(rule gpv.expand)(simp)

(* Restriction leaves Fail unchanged. *)
lemma restrict_gpv_Fail [simp]: "restrict_gpv Fail = Fail"
by(rule gpv.expand)(simp)

(* Restriction of Pause: Fail if the output is not admissible; otherwise the
   Pause is kept with a continuation that is restricted on admissible inputs
   and Fail on the others. *)
lemma restrict_gpv_Pause [simp]: "restrict_gpv (Pause out c) = (if out  outs_ℐ  then Pause out (λinput. if input  responses_ℐ  out then restrict_gpv (c input) else Fail) else Fail)"
by(rule gpv.expand)(simp)

(* Restriction distributes over bind_gpv. Proved by strong coinduction on gpvs. *)
lemma restrict_gpv_bind [simp]: "restrict_gpv (bind_gpv gpv f) = bind_gpv (restrict_gpv gpv) (λx. restrict_gpv (f x))"
apply(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
apply(auto 4 3 simp del: bind_gpv_sel' simp add: bind_gpv.sel bind_spmf_def pmf.rel_map bind_map_pmf rel_fun_def intro!: rel_pmf_bind_reflI rel_pmf_reflI split!: option.split generat.split split: if_split_asm)
done

(* A restricted gpv is always well-typed w.r.t. the interface it was restricted
   to. Proved by coinduction. *)
lemma WT_restrict_gpv [simp]: " ⊢g restrict_gpv gpv "
apply(coinduction arbitrary: gpv)
apply(clarsimp split: option.split_asm)
apply(split generat.split_asm; auto split: if_split_asm)
done

(* Restriction is invisible to execution when nothing can go wrong: for a
   well-typed gpv and a callee that is well-typed in every state, executing the
   restricted gpv equals executing the original one.
   Proved by fixpoint induction on exec_gpv. *)
lemma exec_gpv_restrict_gpv:
  assumes " ⊢g gpv " and WT_callee: "s.  ⊢c callee s "
  shows "exec_gpv callee (restrict_gpv gpv) s = exec_gpv callee gpv s"
using assms(1)
proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step exec_gpv') show ?case 
    by(auto 4 3 simp add: bind_spmf_def bind_map_pmf in_set_spmf[symmetric] WT_gpv_OutD[OF step.prems] WT_calleeD[OF WT_callee] intro!: bind_pmf_cong[OF refl] step.IH split!: option.split generat.split intro: WT_gpv_ContD[OF step.prems])
qed

(* Any output that a restricted gpv can produce (outs'_gpv) belongs to the
   interface's outs_ℐ set. Proved by induction over outs'_gpv. *)
lemma in_outs'_restrict_gpvD: "x  outs'_gpv (restrict_gpv gpv)  x  outs_ℐ "
apply(induction gpv'"restrict_gpv gpv" arbitrary: gpv rule: outs'_gpv_induct)
apply(clarsimp split: option.split_asm; split generat.split_asm; clarsimp split: if_split_asm)+
done

(* Set formulation of in_outs'_restrict_gpvD. *)
lemma outs'_restrict_gpv: "outs'_gpv (restrict_gpv gpv)  outs_ℐ " by(blast intro: in_outs'_restrict_gpvD)

(* Restriction preserves losslessness of a well-typed gpv. Proved by induction
   on lossless_gpv; the two subgoals are the losslessness of the first step and
   of the continuations. *)
lemma lossless_restrict_gpvI: " lossless_gpv  gpv;  ⊢g gpv    lossless_gpv  (restrict_gpv gpv)"
apply(induction rule: lossless_gpv_induct)
apply(rule lossless_gpvI)
subgoal by(clarsimp simp add: lossless_map_pmf lossless_iff_set_pmf_None in_set_spmf[symmetric] WT_gpv_OutD split: option.split_asm generat.split_asm if_split_asm)
subgoal by(clarsimp split: option.split_asm; split generat.split_asm; force simp add: fun_eq_iff in_set_spmf[symmetric] split: if_split_asm intro: WT_gpv_ContD)
done

(* Converse of lossless_restrict_gpvI: if the restriction of a well-typed gpv is
   lossless, so is the gpv itself.
   Proof by induction on the losslessness of the restricted gpv: the first step
   of gpv is lossless because restriction maps it element-wise, and each IO
   step of gpv has a corresponding IO step in the restriction to which the
   induction hypothesis applies. *)
lemma lossless_restrict_gpvD: " lossless_gpv  (restrict_gpv gpv);  ⊢g gpv    lossless_gpv  gpv"
proof(induction gpv'"restrict_gpv gpv" arbitrary: gpv rule: lossless_gpv_induct)
  case (lossless_gpv p)
  from lossless_gpv.hyps(4) have p: "p = the_gpv (restrict_gpv gpv)" by(cases "restrict_gpv gpv") simp
  show ?case
  proof(rule lossless_gpvI) 
    from lossless_gpv.hyps(1) show "lossless_spmf (the_gpv gpv)"
      by(auto simp add: p lossless_iff_set_pmf_None intro: rev_image_eqI)

    fix out c input
    assume IO: "IO out c  set_spmf (the_gpv gpv)" and input: "input  responses_ℐ  out"
    from lossless_gpv.prems(1) IO have out: "out  outs_ℐ " by(rule WT_gpv_OutD)
    hence "IO out (λinput. if input  responses_ℐ  out then restrict_gpv (c input) else Fail)  set_spmf p" using IO
      by(auto simp add: p in_set_spmf intro: rev_bexI)
    from lossless_gpv.hyps(3)[OF this input, of "c input"] WT_gpvD[OF lossless_gpv.prems IO] input
    show "lossless_gpv  (c input)" by simp
  qed
qed
  
(* Coinductive counterpart of lossless_restrict_gpvD: colosslessness of the
   restriction of a well-typed gpv implies colosslessness of the gpv.
   Proved by coinduction; the continuation case maps an IO step of gpv to the
   corresponding IO step of its restriction. *)
lemma colossless_restrict_gpvD:
  " colossless_gpv  (restrict_gpv gpv);  ⊢g gpv    colossless_gpv  gpv"
proof(coinduction arbitrary: gpv)
  case (colossless_gpv gpv)
  have ?lossless_spmf using colossless_gpv(1)[THEN colossless_gpv_lossless_spmfD]
    by(auto simp add: lossless_iff_set_pmf_None intro: rev_image_eqI)
  moreover have ?continuation
  proof(intro strip disjI1)
    fix out c input
    assume IO: "IO out c  set_spmf (the_gpv gpv)" and input: "input  responses_ℐ  out"
    from colossless_gpv(2) IO have out: "out  outs_ℐ " by(rule WT_gpv_OutD)
    hence "IO out (λinput. if input  responses_ℐ  out then restrict_gpv (c input) else Fail)  set_spmf (the_gpv (restrict_gpv gpv))"
      using IO by(auto simp add: in_set_spmf intro: rev_bexI)
    from colossless_gpv_continuationD[OF colossless_gpv(1) this input] input WT_gpv_ContD[OF colossless_gpv(2) IO input]
    show "gpv. c input = gpv  colossless_gpv  (restrict_gpv gpv)   ⊢g gpv " by simp
  qed
  ultimately show ?case ..
qed

(* Coinductive counterpart of lossless_restrict_gpvI: restriction preserves
   colosslessness of a well-typed gpv.
   Proved by coinduction; in the continuation case an IO step of the restricted
   gpv is traced back to an IO step (out, c') of the original gpv. *)
lemma colossless_restrict_gpvI:
  " colossless_gpv  gpv;  ⊢g gpv    colossless_gpv  (restrict_gpv gpv)"
proof(coinduction arbitrary: gpv)
  case (colossless_gpv gpv)
  have ?lossless_spmf using colossless_gpv(1)[THEN colossless_gpv_lossless_spmfD]
    by(auto simp add: lossless_iff_set_pmf_None in_set_spmf[symmetric] split: option.split_asm generat.split_asm if_split_asm dest: WT_gpv_OutD[OF colossless_gpv(2)])
  moreover have ?continuation
  proof(intro strip disjI1)
    fix out c input
    assume IO: "IO out c  set_spmf (the_gpv (restrict_gpv gpv))" and input: "input  responses_ℐ  out"
    then obtain c' where out: "out  outs_ℐ "
      and c: "c = (λinput. if input  responses_ℐ  out then restrict_gpv (c' input) else Fail)"
      and IO': "IO out c'  set_spmf (the_gpv gpv)"
      by(clarsimp split: option.split_asm; split generat.split_asm; clarsimp simp add: in_set_spmf split: if_split_asm)
    with input WT_gpv_ContD[OF colossless_gpv(2) IO' input] colossless_gpv_continuationD[OF colossless_gpv(1) IO' input]
    show "gpv. c input = restrict_gpv gpv  colossless_gpv  gpv   ⊢g gpv " by(auto)
  qed
  ultimately show ?case ..
qed

(* For a well-typed gpv, gen_lossless_gpv b of the restriction and of the gpv
   itself are related in both directions (the proof uses the I and D rules for
   both lossless and colossless), uniformly in the flag b. *)
lemma gen_colossless_restrict_gpv [simp]:
  " ⊢g gpv   gen_lossless_gpv b  (restrict_gpv gpv)  gen_lossless_gpv b  gpv"
by(cases b)(auto intro: lossless_restrict_gpvI lossless_restrict_gpvD colossless_restrict_gpvI colossless_restrict_gpvD)

(* Compares the interaction bound of a restricted gpv with that of the original
   gpv. NOTE(review): the comparison is presumably "at most"; the follow-up
   lemma transfers interaction_bounded_by from gpv to its restriction.
   Proved by fixpoint induction on interaction_bound; the step case also uses
   the hypothesis instantiated at Fail. *)
lemma interaction_bound_restrict_gpv:
  "interaction_bound consider (restrict_gpv gpv)  interaction_bound consider gpv"
proof(induction arbitrary: gpv rule: interaction_bound_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step interaction_bound')
  show ?case using step.hyps(1)[of Fail]
    by(fastforce simp add: SUP_UNION set_spmf_def bind_UNION intro: SUP_mono rev_bexI step.IH split: option.split generat.split)
qed

(* An interaction bound n for gpv is also an interaction bound for its
   restriction; follows from interaction_bound_restrict_gpv. *)
lemma interaction_bounded_by_restrict_gpvI [interaction_bound, simp]:
  "interaction_bounded_by consider gpv n  interaction_bounded_by consider (restrict_gpv gpv) n"
using interaction_bound_restrict_gpv[of "consider" gpv] by(simp add: interaction_bounded_by.simps)

end

(* Parametricity of restrict_gpv for the generalised relator rel_gpv'',
   assuming the call relation C and the response relation R are bi-unique.
   Proved by transfer_prover on the corecursive definition. *)
lemma restrict_gpv_parametric':
  includes lifting_syntax
  notes [transfer_rule] = the_gpv_parametric' Fail_parametric' corec_gpv_parametric'
  assumes [transfer_rule]: "bi_unique C" "bi_unique R"
  shows "(rel_ℐ C R ===> rel_gpv'' A C R ===> rel_gpv'' A C R) restrict_gpv restrict_gpv"
unfolding restrict_gpv_def by transfer_prover

(* Transfer rule for restrict_gpv with the ordinary relator rel_gpv (responses
   related by equality); instance of restrict_gpv_parametric'. *)
lemma restrict_gpv_parametric [transfer_rule]: includes lifting_syntax shows 
  "bi_unique C  (rel_ℐ C (=) ===> rel_gpv A C ===> rel_gpv A C) restrict_gpv restrict_gpv"
using restrict_gpv_parametric'[of C "(=)" A]
by(simp add: bi_unique_eq rel_gpv_conv_rel_gpv'')

(* Mapping over results (calls untouched, id) commutes with restriction.
   Derived from restrict_gpv_parametric instantiated with graph relations. *)
lemma map_restrict_gpv: "map_gpv f id (restrict_gpv  gpv) = restrict_gpv  (map_gpv f id gpv)"
  for gpv :: "('a, 'out, 'ret) gpv"
using restrict_gpv_parametric[of "BNF_Def.Grp UNIV (id :: 'out  'out)" "BNF_Def.Grp UNIV f", where ?'c='ret]
unfolding gpv.rel_Grp by(simp add: eq_alt[symmetric] rel_ℐ_eq rel_fun_def bi_unique_eq)(simp add: Grp_def)

(* Invariant-aware version of exec_gpv_restrict_gpv inside the locale: for a
   well-typed gpv and a start state satisfying I, executing the restricted gpv
   equals executing the original one. Well-typedness of the callee is only
   needed in invariant states, which the invariant maintains along the run. *)
lemma (in callee_invariant_on) exec_gpv_restrict_gpv_invariant:
  assumes " ⊢g gpv " and "I s"
  shows "exec_gpv callee (restrict_gpv  gpv) s = exec_gpv callee gpv s"
using assms
proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step exec_gpv') show ?case using step.prems(2)
    by(auto 4 3 simp add: bind_spmf_def bind_map_pmf in_set_spmf[symmetric] WT_gpv_OutD[OF step.prems(1)] WT_calleeD[OF WT_callee[OF step.prems(2)]] intro!: bind_pmf_cong[OF refl] step.IH split!: option.split generat.split intro: WT_gpv_ContD[OF step.prems(1)] callee_invariant)
qed

(* Destruction rule: every result of restrict_gpv ℐ' gpv is also a result of
   gpv (both taken w.r.t. the same results_gpv interface argument).
   Proof: induction on the results_gpv derivation with the restricted GPV
   abstracted as gpv'; both rule cases are closed by a case distinction on
   the generat y found in the support of the original GPV. *)
lemma in_results_gpv_restrict_gpvD:
  assumes "x  results_gpv  (restrict_gpv ℐ' gpv)"
  shows "x  results_gpv  gpv"
  using assms
  apply(induction gpv'"restrict_gpv ℐ' gpv" arbitrary: gpv)
   apply(clarsimp split: option.split_asm simp add: in_set_spmf[symmetric])
  subgoal for … y by(cases y)(auto intro: results_gpv.intros split: if_split_asm)
  apply(clarsimp split: option.split_asm simp add: in_set_spmf[symmetric])
  subgoal for … y by(cases y)(auto intro: results_gpv.intros split: if_split_asm)
  done

(* Set-level form of in_results_gpv_restrict_gpvD: the results of the
   restricted GPV are contained in the results of the original GPV. *)
lemma results_gpv_restrict_gpv:
  "results_gpv  (restrict_gpv ℐ' gpv)  results_gpv  gpv"
  by(blast intro: in_results_gpv_restrict_gpvD)

(* Variant of in_results_gpv_restrict_gpvD for results'_gpv, obtained by
   instantiating the interface with ℐ_full and rewriting with
   results_gpv_ℐ_full. *)
lemma in_results'_gpv_restrict_gpvD:
  "x  results'_gpv (restrict_gpv ℐ' gpv)  x  results'_gpv gpv"
  by(rule in_results_gpv_restrict_gpvD[where= "ℐ_full", unfolded results_gpv_ℐ_full])

(* Corecursive enforcement of an interface on a GPV.  Reading the definition
   inside-out:
     1. enforce_spmf is applied to the first step of gpv with a pred_generat
        predicate that checks IO outputs for membership in outs_ℐ;
     2. every IO continuation is replaced by one that returns the original
        continuation on inputs in responses_ℐ for that output and Fail otherwise;
     3. enforce_ℐ_gpv is applied corecursively to the continuations.
   The simp rules enforce_ℐ_gpv_Done/_Fail/_Pause that follow in this file
   give the resulting behaviour on the GPV constructors. *)
primcorec enforce_ℐ_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv" where
  "enforce_ℐ_gpv  gpv = GPV 
    (map_spmf (map_generat id id ((∘) (enforce_ℐ_gpv ))) 
     (map_spmf (λgenerat. case generat of Pure x  Pure x | IO out rpv  IO out (λinput. if input  responses_ℐ  out then rpv input else Fail))
        (enforce_spmf (pred_generat  (λx. x  outs_ℐ ) ) (the_gpv gpv))))"

(* Enforcing an interface leaves a finished computation Done x unchanged. *)
lemma enforce_ℐ_gpv_Done [simp]: "enforce_ℐ_gpv  (Done x) = Done x"
  by(rule gpv.expand) simp

(* Enforcing an interface leaves the failing GPV Fail unchanged. *)
lemma enforce_ℐ_gpv_Fail [simp]: "enforce_ℐ_gpv  Fail = Fail"
  by(rule gpv.expand) simp

(* Unfolding rule for a single interaction: if the output passes the outs_ℐ
   test, the Pause is kept and each continuation is enforced recursively on
   inputs passing the responses_ℐ test (Fail on all other inputs);
   otherwise the whole GPV becomes Fail. *)
lemma enforce_ℐ_gpv_Pause [simp]:
  "enforce_ℐ_gpv  (Pause out rpv) =
   (if out  outs_ℐ  then Pause out (λinput. if input  responses_ℐ  out then enforce_ℐ_gpv  (rpv input) else Fail) else Fail)"
  by(rule gpv.expand)(simp add: fun_eq_iff)

(* A lifted spmf (a GPV built by lift_spmf) is unchanged by enforcement. *)
lemma enforce_ℐ_gpv_lift_spmf [simp]: "enforce_ℐ_gpv  (lift_spmf p) = lift_spmf p"
  by(rule gpv.expand)(simp add: enforce_map_spmf spmf.map_comp o_def)

(* Enforcement distributes over monadic bind: enforcing bind_gpv gpv f equals
   binding the enforced gpv with the enforced continuation
   (NOTE(review): the operator combining enforce_ℐ_gpv with f is not legible
   here; presumably function composition -- confirm).
   Proved by strong coinduction on GPVs. *)
lemma enforce_ℐ_gpv_bind_gpv [simp]:
  "enforce_ℐ_gpv  (bind_gpv gpv f) = bind_gpv (enforce_ℐ_gpv  gpv) (enforce_ℐ_gpv   f)"
  by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
    (auto 4 3 simp add: bind_gpv.sel spmf_rel_map bind_map_spmf o_def pred_generat_def elim!: generat.set_cases intro!: generat.rel_refl_strong rel_spmf_bind_reflI rel_spmf_reflI rel_funI split!: if_splits generat.split_asm)

(* Parametricity of enforce_ℐ_gpv for the three-relation GPV relator
   rel_gpv'' A C R with interfaces related by rel_ℐ C R; C and R must be
   bi-unique.  Proved by unfolding the definition (and top_fun_def) and
   running transfer_prover. *)
lemma enforce_ℐ_gpv_parametric':
  includes lifting_syntax 
  notes [transfer_rule] = corec_gpv_parametric' the_gpv_parametric' Fail_parametric'
  assumes [transfer_rule]: "bi_unique C" "bi_unique R"
  shows "(rel_ℐ C R ===> rel_gpv'' A C R ===> rel_gpv'' A C R) enforce_ℐ_gpv enforce_ℐ_gpv"
  unfolding enforce_ℐ_gpv_def top_fun_def by(transfer_prover)

(* Transfer rule: parametricity of enforce_ℐ_gpv for rel_gpv A C, i.e. the
   primed version with the response relation instantiated to equality
   (whose bi-uniqueness is bi_unique_eq). *)
lemma enforce_ℐ_gpv_parametric [transfer_rule]: includes lifting_syntax shows
  "bi_unique C  (rel_ℐ C (=) ===> rel_gpv A C ===> rel_gpv A C) enforce_ℐ_gpv enforce_ℐ_gpv"
  unfolding rel_gpv_conv_rel_gpv'' by(rule enforce_ℐ_gpv_parametric'[OF _ bi_unique_eq])

(* The result of enforce_ℐ_gpv always satisfies the GPV well-typedness
   judgement, with no assumption on the input gpv
   (NOTE(review): presumably w.r.t. the same interface that is enforced --
   the interface argument is not legible here; confirm).
   Proved by coinduction. *)
lemma WT_enforce_ℐ_gpv [simp]: " ⊢g enforce_ℐ_gpv  gpv "
  by(coinduction arbitrary: gpv)(auto split: generat.split_asm)

context fixes  :: "('out, 'in) ℐ" begin

(* Finiteness of a GPV relative to the interface fixed by the enclosing
   context.  Single introduction rule: gpv is finite if, for every IO step
   "IO out c" in the support of its first step and every input passing the
   responses_ℐ test for out, the continuation c input is finite.
   Because the predicate is inductive, only GPVs whose interaction trees
   (along such inputs) are well-founded satisfy it. *)
inductive finite_gpv :: "('a, 'out, 'in) gpv  bool"
where
  finite_gpvI: 
  "(out c input.  IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out   finite_gpv (c input))  finite_gpv gpv"

(* Re-declare the induction rule of finite_gpv: it consumes the finite_gpv
   premise, names its single case "finite_gpv", and is the default rule for
   induction on this predicate. *)
lemmas finite_gpv_induct[consumes 1, case_names finite_gpv, induct pred] = finite_gpv.induct

(* Destruction rule (inverse of finite_gpvI): continuations of a finite GPV,
   applied to inputs passing the responses_ℐ test, are finite again. *)
lemma finite_gpvD: " finite_gpv gpv; IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out   finite_gpv (c input)"
by(auto elim: finite_gpv.cases)

(* Fail is finite. *)
lemma finite_gpv_Fail [simp]: "finite_gpv Fail"
by(auto intro: finite_gpvI)

(* Done x is finite. *)
lemma finite_gpv_Done [simp]: "finite_gpv (Done x)"
by(auto intro: finite_gpvI)

(* Simp rule for a single interaction: finiteness of Pause x c is
   characterised by finiteness of c input for the inputs in the
   responses_ℐ set of x.
   NOTE(review): connective and quantifier are not legible here; presumably
   an equivalence with a bounded universal quantifier -- confirm. *)
lemma finite_gpv_Pause [simp]: "finite_gpv (Pause x c)  (input  responses_ℐ  x. finite_gpv (c input))"
by(auto dest: finite_gpvD intro: finite_gpvI)

(* A lifted spmf is finite. *)
lemma finite_gpv_lift_spmf [simp]: "finite_gpv (lift_spmf p)"
by(auto intro: finite_gpvI)

(* Finiteness of a monadic bind: the bind of gpv and f is finite iff gpv is
   finite and f x is finite for every x in results_gpv of gpv (the proof
   skeleton "intro iffI conjI ballI" fixes this logical shape).
   The proof has three parts:
     - left to right, first conjunct: induction on finiteness of the bind;
     - left to right, second conjunct: induction on the results_gpv derivation;
     - right to left: induction on finiteness of gpv. *)
lemma finite_gpv_bind [simp]:
  "finite_gpv (gpv  f)  finite_gpv gpv  (xresults_gpv  gpv. finite_gpv (f x))"
  (is "?lhs = ?rhs")
proof(intro iffI conjI ballI; (elim conjE)?)
  (* Part 1: the first component of a finite bind is finite *)
  show "finite_gpv gpv" if "?lhs" using that
  proof(induction gpv'"gpv  f" arbitrary: gpv)
    case finite_gpv
    show ?case
    proof(rule finite_gpvI)
      fix out c input
      assume IO: "IO out c  set_spmf (the_gpv gpv)"
        and input: "input  responses_ℐ  out"
      (* an IO step of gpv shows up as an IO step of the bind *)
      have "IO out (λinput. c input  f)  set_spmf (the_gpv (gpv  f))"
        using IO by(auto intro: rev_bexI)
      thus "finite_gpv (c input)" using input by(rule finite_gpv.hyps) simp
    qed
  qed
  (* Part 2: continuations f x on results x of gpv are finite *)
  show "finite_gpv (f x)" if "x  results_gpv  gpv" ?lhs for x using that
  proof(induction)
    case (Pure gpv)
    show ?case
    proof
      fix out c input
      assume "IO out c  set_spmf (the_gpv (f x))" "input  responses_ℐ  out"
      with Pure have "IO out c  set_spmf (the_gpv (gpv  f))" by(auto intro: rev_bexI)
      with Pure.prems show "finite_gpv (c input)" by(rule finite_gpvD) fact
    qed
  next
    case (IO out c gpv input)
    with IO.hyps have "IO out (λinput. c input  f)  set_spmf (the_gpv (gpv  f))"
      by(auto intro: rev_bexI)
    with IO.prems have "finite_gpv (c input  f)" using IO.hyps(2) by(rule finite_gpvD)
    thus ?case by(rule IO.IH)
  qed
  (* Part 3: converse direction *)
  show ?lhs if "finite_gpv gpv" "xresults_gpv  gpv. finite_gpv (f x)" using that
  proof induction
    case (finite_gpv gpv)
    show ?case
    proof(rule finite_gpvI)
      fix out c input
      assume IO: "IO out c  set_spmf (the_gpv (gpv  f))" and input: "input  responses_ℐ  out"
      (* trace the IO step of the bind back to a generat of gpv *)
      then obtain generat where generat: "generat  set_spmf (the_gpv gpv)"
        and IO: "IO out c  set_spmf (if is_Pure generat then the_gpv (f (result generat)) else
                   return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
        by(auto)
      show "finite_gpv (c input)"
      proof(cases generat)
        case (Pure x)
        with generat IO have "x  results_gpv  gpv" "IO out c  set_spmf (the_gpv (f x))"
          by(auto intro: results_gpv.Pure)
        thus ?thesis using finite_gpv.prems input by(auto dest: finite_gpvD)
      next
        case *: (IO out' c')
        with IO generat finite_gpv.prems input show ?thesis
          by(auto 4 4 intro: finite_gpv.IH results_gpv.IO)
      qed
    qed
  qed
qed

end

context includes lifting_syntax begin

(* Finiteness transfers from left to right along rel_gpv'' A C R: if gpv and
   gpv' are related, gpv is finite, and the two interfaces are related by
   rel_ℐ C R, then gpv' is finite w.r.t. ℐ'.
   Proof: induction on finiteness of gpv; each IO step of gpv' is matched with
   a related IO step of gpv via parametricity of the_gpv and responses_ℐ. *)
lemma finite_gpv_rel''D1:
  assumes "rel_gpv'' A C R gpv gpv'" and "finite_gpv  gpv" and: "rel_ℐ C R  ℐ'"
  shows "finite_gpv ℐ' gpv'"
using assms(2,1)
proof(induction arbitrary: gpv')
  case (finite_gpv gpv)
  note finite_gpv.prems[transfer_rule]
  show ?case
  proof(rule finite_gpvI)
    fix out' c' input'
    assume IO: "IO out' c'  set_spmf (the_gpv gpv')" and input': "input'  responses_ℐ ℐ' out'"
    (* the supports of the first steps are related elementwise *)
    have "rel_set (rel_generat A C (R ===> (rel_gpv'' A C R))) (set_spmf (the_gpv gpv)) (set_spmf (the_gpv gpv'))"
      supply the_gpv_parametric'[transfer_rule] by transfer_prover
    with IO input' responses_ℐ_parametric[THEN rel_funD, OF] obtain out c input
      where "IO out c  set_spmf (the_gpv gpv)" "input  responses_ℐ  out" "rel_gpv'' A C R (c input) (c' input')"
      by(auto 4 3 dest!: rel_setD2 elim!: generat.rel_cases dest: rel_funD)
    then show "finite_gpv ℐ' (c' input')" by(rule finite_gpv.IH)
  qed
qed

(* Specialisation of finite_gpv_rel''D1 to rel_gpv A C, i.e. with the
   response relation instantiated to equality. *)
lemma finite_gpv_relD1: " rel_gpv A C gpv gpv'; finite_gpv  gpv; rel_ℐ C (=)     finite_gpv  gpv'"
using finite_gpv_rel''D1[of A C "(=)" gpv gpv'  ] by(simp add: rel_gpv_conv_rel_gpv'')

(* Right-to-left counterpart of finite_gpv_rel''D1: finiteness of gpv'
   implies finiteness of the related gpv.  Obtained by applying
   finite_gpv_rel''D1 to the converse relations. *)
lemma finite_gpv_rel''D2: " rel_gpv'' A C R gpv gpv'; finite_gpv  gpv'; rel_ℐ C R ℐ'    finite_gpv ℐ' gpv"
using finite_gpv_rel''D1[of "A¯¯" "C¯¯" "R¯¯" gpv' gpv  ℐ'] by(simp add: rel_gpv''_conversep)

(* Specialisation of finite_gpv_rel''D2 to rel_gpv A C, i.e. with the
   response relation instantiated to equality. *)
lemma finite_gpv_relD2: " rel_gpv A C gpv gpv'; finite_gpv  gpv'; rel_ℐ C (=)     finite_gpv  gpv"
using finite_gpv_rel''D2[of A C "(=)" gpv gpv'  ] by(simp add: rel_gpv_conv_rel_gpv'')

(* Parametricity of finite_gpv for rel_gpv'': related interfaces and related
   GPVs give equal truth values.  Combines the two directions
   finite_gpv_rel''D1 and finite_gpv_rel''D2. *)
lemma finite_gpv_parametric': "(rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) finite_gpv finite_gpv"
by(blast dest: finite_gpv_rel''D2 finite_gpv_rel''D1)

(* Transfer rule: parametricity of finite_gpv for rel_gpv A C, the primed
   version with the response relation instantiated to equality. *)
lemma finite_gpv_parametric [transfer_rule]: "(rel_ℐ C (=) ===> rel_gpv A C ===> (=)) finite_gpv finite_gpv"
using finite_gpv_parametric'[of C "(=)" A] by(simp add: rel_gpv_conv_rel_gpv'')

end

(* Mapping the results of a GPV with f (outputs untouched) does not change
   whether it is finite.  Derived from finite_gpv_parametric instantiated
   with graph relations. *)
lemma finite_gpv_map [simp]: "finite_gpv  (map_gpv f id gpv) = finite_gpv  gpv"
using finite_gpv_parametric[of "BNF_Def.Grp UNIV id" "BNF_Def.Grp UNIV f"]
unfolding gpv.rel_Grp by(auto simp add: rel_fun_def BNF_Def.Grp_def eq_commute rel_ℐ_eq)

(* assert_gpv b is finite for either value of b (case distinction on b). *)
lemma finite_gpv_assert [simp]: "finite_gpv  (assert_gpv b)"
by(cases b) simp_all

(* Finiteness of TRY gpv ELSE gpv': equivalent to gpv being finite together
   with (gpv colossless or gpv' finite); the proof skeleton
   "intro iffI conjI; elim conjE disjE" fixes this logical shape.
   Parts of the proof:
     1. a finite TRY has a finite first component;
     2. if additionally gpv is not colossless, the fallback gpv' is finite;
     3./4. the two converse cases, one per disjunct. *)
lemma finite_gpv_try [simp]:
  "finite_gpv  (TRY gpv ELSE gpv')  finite_gpv  gpv  (colossless_gpv  gpv  finite_gpv  gpv')"
  (is "?lhs = _")
proof(intro iffI conjI; (elim conjE disjE)?)
  (* Part 1 *)
  show 1: "finite_gpv  gpv" if ?lhs using that
  proof(induction gpv''"TRY gpv ELSE gpv'" arbitrary: gpv)
    case (finite_gpv gpv)
    show ?case
    proof(rule finite_gpvI)
      fix out c input
      assume IO: "IO out c  set_spmf (the_gpv gpv)" and input: "input  responses_ℐ  out"
      from IO have "IO out (λinput. TRY c input ELSE gpv')  set_spmf (the_gpv (TRY gpv ELSE gpv'))"
        by(auto simp add: image_image generat.map_comp o_def intro: rev_image_eqI)
      thus "finite_gpv  (c input)" using input by(rule finite_gpv.hyps) simp
    qed
  qed
  (* Part 2: case split on whether the first step of gpv is lossless *)
  have "finite_gpv  gpv'" if "?lhs" "¬ colossless_gpv  gpv" using that
  proof(induction gpv''"TRY gpv ELSE gpv'" arbitrary: gpv)
    case (finite_gpv gpv)
    show ?case
    proof(cases "lossless_spmf (the_gpv gpv)")
      case True
      (* first step is lossless, so non-colosslessness stems from a continuation *)
      have "out c input. IO out c  set_spmf (the_gpv gpv)  input  responses_ℐ  out  ¬ colossless_gpv  (c input)"
        using finite_gpv.prems by(rule contrapos_np)(auto intro: colossless_gpvI simp add: True)
      then obtain out c input where IO: "IO out c  set_spmf (the_gpv gpv)"
        and co': "¬ colossless_gpv  (c input)" 
        and input: "input  responses_ℐ  out" by blast
      from IO have "IO out (λinput. TRY c input ELSE gpv')  set_spmf (the_gpv (TRY gpv ELSE gpv'))"
        by(auto simp add: image_image generat.map_comp o_def intro: rev_image_eqI)
      with co' show ?thesis using input by(blast intro: finite_gpv.hyps(2))
    next
      case False
      (* first step is lossy, so the steps of gpv' occur directly in the TRY *)
      show ?thesis
      proof(rule finite_gpvI)
        fix out c input
        assume IO: "IO out c  set_spmf (the_gpv gpv')" and input: "input  responses_ℐ  out"
        from IO False have "IO out c  set_spmf (the_gpv (TRY gpv ELSE gpv'))" by(auto intro: rev_image_eqI)
        then show "finite_gpv  (c input)" using input by(rule finite_gpv.hyps)
      qed
    qed
  qed
  then show "colossless_gpv  gpv  finite_gpv  gpv'" if ?lhs using that by blast
  
  (* Part 3: both components finite *)
  show ?lhs if "finite_gpv  gpv" "finite_gpv  gpv'" using that(1)
  proof induction
    case (finite_gpv gpv)
    show ?case
    proof
      fix out c input
      assume IO: "IO out c  set_spmf (the_gpv (TRY gpv ELSE gpv'))"
        and input: "input  responses_ℐ  out"
      (* an IO step of the TRY stems either from gpv or from gpv' *)
      then consider (gpv) c' where "IO out c'  set_spmf (the_gpv gpv)" "c = (λinput. TRY c' input ELSE gpv')"
        | (gpv') "IO out c  set_spmf (the_gpv gpv')" by(auto split: if_split_asm)
      then show "finite_gpv  (c input)" using input
        by cases(auto intro: finite_gpv.IH finite_gpvD[OF that(2)])
    qed
  qed
  (* Part 4: gpv finite and colossless *)
  show ?lhs if "finite_gpv  gpv" "colossless_gpv  gpv" using that
  proof induction
    case (finite_gpv gpv)
    show ?case
      by(rule finite_gpvI)(use finite_gpv.prems in fastforce split: if_split_asm dest: colossless_gpvD intro: finite_gpv.IH›)
  qed
qed

(* Characterisation of lossless_gpv as finite_gpv together with
   colossless_gpv (the proof skeleton "intro iffI conjI" fixes the shape:
   an equivalence with a conjunction on the right).
   Left to right: two inductions on lossless_gpv.
   Right to left: induction on finite_gpv with a case analysis of the
   colossless_gpv premise. *)
lemma lossless_gpv_conv_finite:
  "lossless_gpv  gpv  finite_gpv  gpv  colossless_gpv  gpv"
  (is "?loss  ?fin  ?co")
proof(intro iffI conjI; (elim conjE)?)
  show ?fin if ?loss using that by induction(auto intro: finite_gpvI)
  show ?co if ?loss using that by induction(auto intro: colossless_gpvI)
  show ?loss if ?fin ?co using that
  proof induction
    case (finite_gpv gpv)
    from finite_gpv.prems finite_gpv.IH show ?case
      by cases(auto intro: lossless_gpvI)
  qed
qed

(* Colosslessness of TRY gpv ELSE gpv': equivalent to gpv or gpv' being
   colossless (the proof skeleton "intro iffI disjCI; elim disjE" fixes this
   shape).
   Left to right: assuming gpv' is not colossless, colosslessness of gpv is
   shown by coinduction.  Right to left: one coinduction per disjunct. *)
lemma colossless_gpv_try [simp]:
  "colossless_gpv  (TRY gpv ELSE gpv')  colossless_gpv  gpv  colossless_gpv  gpv'"
  (is "?lhs  ?gpv  ?gpv'")
proof(intro iffI disjCI; (elim disjE)?)
  show "?gpv" if ?lhs "¬ ?gpv'" using that(1)
  proof(coinduction arbitrary: gpv)
    case (colossless_gpv gpv)
    (* the first step of gpv must be lossless: otherwise steps of gpv' would
       occur in the TRY, contradicting that gpv' is not colossless *)
    have ?lossless_spmf
    proof(rule ccontr)
      assume loss: "¬ ?lossless_spmf"
      with colossless_gpv_lossless_spmfD[OF colossless_gpv(1)]
      have gpv': "lossless_spmf (the_gpv gpv')" by auto
      have "out c input. IO out c  set_spmf (the_gpv gpv')  input  responses_ℐ  out  ¬ colossless_gpv  (c input)"
        using that(2) by(rule contrapos_np)(auto intro: colossless_gpvI gpv')
      then obtain out c input
        where IO: "IO out c  set_spmf (the_gpv gpv')"
        and co': "¬ colossless_gpv  (c input)" 
        and input: "input  responses_ℐ  out" by blast
      from IO loss have "IO out c  set_spmf (the_gpv (TRY gpv ELSE gpv'))"
        by(auto intro: rev_image_eqI)
      with colossless_gpv(1) have "colossless_gpv  (c input)" using input
        by(rule colossless_gpv_continuationD)
      with co' show False by contradiction
    qed
    (* continuations of gpv stay inside the coinduction invariant *)
    moreover have ?continuation
    proof(intro strip disjI1; simp)
      fix out c input
      assume IO: "IO out c  set_spmf (the_gpv gpv)" and input: "input  responses_ℐ  out"
      hence "IO out (λinput. TRY c input ELSE gpv')  set_spmf (the_gpv (TRY gpv ELSE gpv'))"
        by(auto intro: rev_image_eqI)
      with colossless_gpv show "colossless_gpv  (TRY c input ELSE gpv')"
        by(rule colossless_gpv_continuationD)(simp add: input)
    qed
    ultimately show ?case ..
  qed
  show ?lhs if ?gpv'
  proof(coinduction arbitrary: gpv)
    case colossless_gpv
    show ?case using colossless_gpvD[OF that] by(auto 4 3)
  qed
  show ?lhs if ?gpv using that
  proof(coinduction arbitrary: gpv)
    case colossless_gpv
    show ?case using colossless_gpvD[OF colossless_gpv] by(auto 4 3)
  qed
qed

(* Losslessness of TRY gpv ELSE gpv', expressed through finiteness of gpv and
   losslessness of gpv / gpv'.
   NOTE(review): the connectives are not legible here; presumably
   "lhs iff finite gpv and (lossless gpv or lossless gpv')" -- confirm.
   Proved by rewriting lossless_gpv with lossless_gpv_conv_finite. *)
lemma lossless_gpv_try [simp]:
  "lossless_gpv  (TRY gpv ELSE gpv')  
   finite_gpv  gpv  (lossless_gpv  gpv  lossless_gpv  gpv')"
by(auto simp add: lossless_gpv_conv_finite)

(* A GPV whose interactions are bounded by a finite number (enat n) is finite
   w.r.t. the full interface ℐ_full.  Proof by induction on n, using
   interaction_bounded_by_contD for the continuations. *)
lemma interaction_any_bounded_by_imp_finite:
  assumes "interaction_any_bounded_by gpv (enat n)"
  shows "finite_gpv ℐ_full gpv"
using assms 
proof(induction n arbitrary: gpv)
  case 0
  then show ?case by(auto intro: finite_gpv.intros dest: interaction_bounded_by_contD simp add: zero_enat_def[symmetric])
next
  case (Suc n)
  from Suc.prems show ?case unfolding eSuc_enat[symmetric]
    by(auto 4 4 intro: finite_gpv.intros Suc.IH dest: interaction_bounded_by_contD)
qed

(* Restricting a GPV preserves finiteness; the interface used for finiteness
   (ℐ') is independent of the one used for restriction.
   Proof: induction on finite_gpv, then case analysis of the first step. *)
lemma finite_restrict_gpvI [simp]: "finite_gpv ℐ' gpv  finite_gpv ℐ' (restrict_gpv  gpv)"
by(induction rule: finite_gpv_induct)(rule finite_gpvI; clarsimp split: option.split_asm; split generat.split_asm; clarsimp split: if_split_asm simp add: in_set_spmf)

(* Bound on the probability that the flag "bad" holds in the final state after
   running gpv against callee from state s: it is related to ennreal k * n.
   NOTE(review): the comparison symbols in the assumptions and conclusion are
   not legible here; the readings below are presumptions -- confirm.
   Assumptions, by name:
     bound     - gpv makes at most n interactions satisfying "consider";
     good      - the start state is not bad;
     count     - a considered call raises the counter "count" by at most one;
     ignore    - a call that is not considered does not raise the counter;
     bad       - from a non-bad state whose counter is below n + count s, a
                 considered call sets bad with probability at most k;
     consider  - only considered calls can turn a non-bad state bad;
     k_nonneg  - k is non-negative;
     WT_gpv, WT_callee - well-typedness of gpv and of the callee.
   Proof: fixpoint induction on exec_gpv; the step case is a calculation with
   non-negative integrals over the IO steps of gpv (restricted measure ?M). *)
lemma interaction_bounded_by_exec_gpv_bad_count:
  fixes count and bad and n :: enat and k :: real
  assumes bound: "interaction_bounded_by consider gpv n"
  and good: "¬ bad s"
  and count: "s x y s'.  (y, s')  set_spmf (callee s x); consider x; x  outs_ℐ    count s'  Suc (count s)"
  and ignore: "s x y s'.  (y, s')  set_spmf (callee s x); ¬ consider x; x  outs_ℐ    count s'  count s"
  and bad: "s' x.  ¬ bad s'; count s' < n + count s; consider x; x  outs_ℐ    spmf (map_spmf (bad  snd) (callee s' x)) True  k"
  and "consider": "s x y s'.  (y, s')  set_spmf (callee s x); ¬ bad s; bad s'; x  outs_ℐ    consider x"
  and k_nonneg: "k  0"
  and WT_gpv: " ⊢g gpv "
  and WT_callee: "s.  ⊢c callee s "
  shows "spmf (map_spmf (bad  snd) (exec_gpv callee gpv s)) True  ennreal k * n"
using bound good bad WT_gpv
proof(induction arbitrary: gpv s n rule: exec_gpv_fixp_induct)
  case adm show ?case by(rule cont_intro ccpo_class.admissible_leI)+
  case bottom show ?case using k_nonneg by(simp add: zero_ereal_def[symmetric])
next
  case (step exec_gpv')
  (* ?M: the first-step distribution of gpv restricted to IO steps *)
  let ?M = "restrict_space (measure_spmf (the_gpv gpv)) {IO out c|out c. True}"
  (* push the map "bad o snd" inside the unfolded execution *)
  have "ennreal (spmf (map_spmf (bad  snd) (bind_spmf (the_gpv gpv) (case_generat (λx. return_spmf (x, s)) (λout c. bind_spmf (callee s out) (λ(x, y). exec_gpv' (c x) y))))) True) =
    ennreal (spmf (bind_spmf (the_gpv gpv) (λgenerat. case generat of Pure x  return_spmf (bad s) |
       IO out rpv  bind_spmf (callee s out) (λ(x, s'). map_spmf (bad  snd) (exec_gpv' (rpv x) s')))) True)"
    (is "_ = ennreal (spmf (bind_spmf _ (case_generat _ ?io)) _)")
    by(simp add: map_spmf_bind_spmf o_def generat.case_distrib[where h="map_spmf _"] split_def cong del: generat.case_cong_weak)
  (* rewrite as an iterated integral; the Pure case vanishes because s is not bad *)
  also have " = + generat. + (x, s'). spmf (map_spmf (bad  snd) (exec_gpv' (continuation generat x) s')) True measure_spmf (callee s (output generat)) ?M"
    using step.prems(2) by(auto simp add: ennreal_spmf_bind nn_integral_restrict_space intro!: nn_integral_cong split: generat.split)
  (* bound the recursive call: 1 if already bad, induction hypothesis otherwise *)
  also have "  + generat. + (x, s'). (if bad s' then 1 else ennreal k * (if consider (output generat) then n - 1 else n)) measure_spmf (callee s (output generat)) ?M"
  proof(clarsimp intro!: nn_integral_mono_AE simp add: AE_restrict_space_iff split del: if_split cong del: if_cong)
    show "ennreal (spmf (map_spmf (bad  snd) (exec_gpv' (rpv ret) s')) True)
           (if bad s' then 1 else ennreal k * ennreal_of_enat (if consider out then n - 1 else n))"
      if IO: "IO out rpv  set_spmf (the_gpv gpv)"
      and call: "(ret, s')  set_spmf (callee s out)"
      for out rpv ret s'
    proof(cases "bad s'")
      case True
      then show ?thesis by(simp add: pmf_le_1)
    next
      case False
      (* remaining budget after this call *)
      let ?n' = "if consider out then n - 1 else n"
      have out: "out  outs_ℐ " using IO step.prems(4) by(simp add: WT_gpv_OutD)
      have bound': "interaction_bounded_by consider (rpv ret) ?n'"
        using interaction_bounded_by_contD[OF step.prems(1) IO]
              interaction_bounded_by_contD_ignore[OF step.prems(1) IO] by(auto)
      have "ret  responses_ℐ  out" using WT_callee call out by(rule WT_calleeD)
      with step.prems(4) IO have WT': " ⊢g rpv ret " by(rule WT_gpv_ContD)
      (* re-establish the "bad" assumption relative to the new state s' and budget ?n' *)
      have bad':  "spmf (map_pmf (map_option (bad  snd)) (callee s'' x)) True  k"
        if "¬ bad s''" and count': "count s'' < ?n' + count s'" and "consider x" and "x  outs_ℐ "
        for s'' x using ¬ bad s'' _ consider x x  outs_ℐ 
      proof(rule step.prems)
        show "count s'' < n + count s"
        proof(cases "consider out")
          case True
          with count[OF call True out] count' interaction_bounded_by_contD[OF step.prems(1) IO, of undefined]
          show ?thesis by(cases n)(auto simp add: one_enat_def)
        next
          case False
          with ignore[OF call _ out] count' show ?thesis by(cases n)auto
        qed
      qed
      from step.IH[OF bound' False this] False WT' show ?thesis by(auto simp add: o_def)
    qed
  qed
  (* express the inner integrand over the Boolean outcome of "bad o snd" *)
  also have " = + generat. + b. indicator {True} b + ennreal k * (if consider (output generat) then n - 1 else n) * indicator {False} b measure_spmf (map_spmf (bad  snd) (callee s (output generat))) ?M"
    (is "_ = + generat. + _. _ ?O' generat _")
    by(auto intro!: nn_integral_cong)
  also have " = + generat. (+ b. indicator {True} b ?O' generat) + ennreal k * (if consider (output generat) then n - 1 else n) * + b. indicator {False} b ?O' generat ?M"
    by(subst nn_integral_add)(simp_all add: k_nonneg nn_integral_cmult o_def)
  also have " = + generat. ennreal (spmf (map_spmf (bad  snd) (callee s (output generat))) True) + ennreal k * (if consider (output generat) then n - 1 else n) * spmf (map_spmf (bad  snd) (callee s (output generat))) False ?M"
    by(simp del: nn_integral_map_spmf add: emeasure_spmf_single ereal_of_enat_mult)
  (* pointwise bound of the integrand by k * n, split on "consider out" *)
  also have "  + generat. ennreal k * n ?M"
  proof(intro nn_integral_mono_AE, clarsimp intro!: nn_integral_mono_AE simp add: AE_restrict_space_iff not_is_Pure_conv split del: if_split)
    fix out c
    assume IO: "IO out c  set_spmf (the_gpv gpv)"
    with step.prems(4) have out: "out  outs_ℐ " by(rule WT_gpv_OutD)
    show "spmf (map_spmf (bad  snd) (callee s out)) True +
           ennreal k * (if consider out then n - 1 else n) * spmf (map_spmf (bad  snd) (callee s out)) False
           ennreal k * n"
    proof(cases "consider out")
      case True
      with IO have "n > 0" using interaction_bounded_by_contD[OF step.prems(1)] by(blast dest: interaction_bounded_by_contD)
      have "spmf (map_spmf (bad  snd) (callee s out)) True  k" (is "?o True  _")
        using ¬ bad s True n > 0 out by(intro step.prems)(simp)
      hence "ennreal (?o True)  k" using k_nonneg by(simp del: o_apply)
      hence "?o True + ennreal k * (n - 1) * ?o False  ennreal k + ennreal k * (n - 1) * ennreal 1"
        by(rule add_mono)(rule mult_left_mono, simp_all add: pmf_le_1 k_nonneg)
      also have "  ennreal k * n" using n > 0
        by(cases n)(auto simp add: zero_enat_def ennreal_top_mult gr0_conv_Suc eSuc_enat[symmetric] field_simps)
      finally show ?thesis using True by(simp del: o_apply add: ereal_of_enat_mult)
    next
      case False
      (* a call that is not considered cannot set bad *)
      hence "spmf (map_spmf (bad  snd) (callee s out)) True = 0" using ¬ bad s out
        unfolding spmf_eq_0_set_spmf by(auto dest: "consider")
      with False k_nonneg pmf_le_1[of "map_spmf (bad  snd) (callee s out)" "Some False"]
      show ?thesis by(simp add: mult_left_mono[THEN order_trans, where ?b1=1])
    qed
  qed
  (* integrating the constant over the sub-probability space ?M *)
  also have "  ennreal k * n"
    by(simp add: k_nonneg emeasure_restrict_space measure_spmf.emeasure_eq_measure space_restrict_space measure_spmf.subprob_measure_le_1 mult_left_mono[THEN order_trans, where ?b1=1])
  finally show ?case by(simp del: o_apply)
qed

context callee_invariant_on begin

(* Invariant-relative version of the global lemma of the same name: the
   assumptions on count, ignore, bad and consider are only required for
   states satisfying the invariant I, and the start state must satisfy I.
   Proof technique: local type definition.  Under the hypothesis that some
   type 's' is in bijection with {s. I s}, the callee, bad and count are
   transferred to that type, the global lemma is applied to the restricted
   GPV gpv', the result is transferred back, and
   exec_gpv_restrict_gpv_invariant removes the restriction.  The hypothesis is
   finally discharged with cancel_type_definition; the last step supplies I s
   (presumably to show {s. I s} is non-empty -- confirm). *)
lemma interaction_bounded_by_exec_gpv_bad_count:
  includes lifting_syntax
  fixes count and bad and n :: enat
  assumes bound: "interaction_bounded_by consider gpv n"
  and I: "I s"
  and good: "¬ bad s"
  and count: "s x y s'.  (y, s')  set_spmf (callee s x); I s; consider x; x  outs_ℐ    count s'  Suc (count s)"
  and ignore: "s x y s'.  (y, s')  set_spmf (callee s x); I s; ¬ consider x; x  outs_ℐ    count s'  count s"
  and bad: "s' x.  I s'; ¬ bad s'; count s' < n + count s; consider x; x  outs_ℐ    spmf (map_spmf (bad  snd) (callee s' x)) True  k"
  and "consider": "s x y s'.  (y, s')  set_spmf (callee s x); I s; ¬ bad s; bad s'; x  outs_ℐ    consider x"
  and k_nonneg: "k  0"
  and WT_gpv: " ⊢g gpv "
  shows "spmf (map_spmf (bad  snd) (exec_gpv callee gpv s)) True  ennreal k * n"
proof -
  { assume "(Rep :: 's'  's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s'  's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    (* cr relates a state with its representative in the new type *)
    define cr where "cr  λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp
    
    (* ?C: equality restricted to outputs passing the outs_ℐ test *)
    let ?C = "eq_onp (λx. x  outs_ℐ )"

    (* transferred copies of callee, start state, bad and count *)
    define callee' where "callee'  (Rep ---> id ---> map_spmf (map_prod id Abs)) callee"
    have [transfer_rule]: "(cr ===> ?C ===> rel_spmf (rel_prod (=) cr)) callee callee'"
      by(auto simp add: callee'_def rel_fun_def cr_def spmf_rel_map prod.rel_map td.Abs_inverse eq_onp_def intro!: rel_spmf_reflI intro: td.Rep[simplified] dest: callee_invariant)
    define s' where "s'  Abs s"
    have [transfer_rule]: "cr s s'" using I by(simp add: cr_def s'_def td.Abs_inverse)
    define bad' where "bad'  (Rep ---> id) bad"
    have [transfer_rule]: "(cr ===> (=)) bad bad'" by(simp add: rel_fun_def bad'_def cr_def)
    define count' where "count'  (Rep ---> id) count"
    have [transfer_rule]: "(cr ===> (=)) count count'" by(simp add: rel_fun_def count'_def cr_def)

    have [transfer_rule]: "(?C ===> (=)) consider consider" by(simp add: eq_onp_def rel_fun_def)
    have [transfer_rule]: "rel_ℐ ?C (=)  "
      by(rule rel_ℐI)(auto simp add: rel_set_eq set_relator_eq_onp eq_onp_same_args dest: eq_onp_to_eq)
    note [transfer_rule] = bi_unique_eq_onp bi_unique_eq

    (* work with the restricted GPV, which is related to itself under ?C *)
    define gpv' where "gpv'  restrict_gpv  gpv"
    have [transfer_rule]: "rel_gpv (=) ?C gpv' gpv'"
      by(fold eq_onp_top_eq_eq)(auto simp add: gpv.rel_eq_onp eq_onp_same_args pred_gpv_def gpv'_def dest: in_outs'_restrict_gpvD)

    (* collect the premises of the global lemma for the transferred objects *)
    have "interaction_bounded_by consider gpv' n" using bound by(simp add: gpv'_def)
    moreover have "¬ bad' s'" using good by transfer
    moreover have [rule_format, rotated]:
      "s y s'. x  outs_ℐ . (y, s')  set_spmf (callee' s x)  consider x  count' s'  Suc (count' s)"
      by(transfer fixing: "consider")(blast intro: count)
    moreover have [rule_format, rotated]:
      "s y s'. x  outs_ℐ . (y, s')  set_spmf (callee' s x)  ¬ consider x  count' s'  count' s"
      by(transfer fixing: "consider")(blast intro: ignore)
    moreover have [rule_format, rotated]: 
      "s''. x  outs_ℐ . ¬ bad' s''  count' s'' < n + count' s'  consider x  spmf (map_spmf (bad'  snd) (callee' s'' x)) True  k"
      by(transfer fixing: "consider" k n)(blast intro: bad)
    moreover have [rule_format, rotated]: 
      "s y s'. x  outs_ℐ . (y, s')  set_spmf (callee' s x)  ¬ bad' s  bad' s'  consider x"
      by(transfer fixing: "consider")(blast intro: "consider")
    moreover note k_nonneg
    moreover have " ⊢g gpv' " by(simp add: gpv'_def)
    moreover have "s.  ⊢c callee' s " by transfer(rule WT_callee)
    ultimately have **: "spmf (map_spmf (bad'  snd) (exec_gpv callee' gpv' s')) True  ennreal k * n"
      by(rule interaction_bounded_by_exec_gpv_bad_count)
    (* transfer the bound back to the original state type *)
    have [transfer_rule]: "((=) ===> ?C ===> rel_spmf (rel_prod (=) (=))) callee callee"
      by(simp add: rel_fun_def eq_onp_def prod.rel_eq)
    have "spmf (map_spmf (bad  snd) (exec_gpv callee gpv' s)) True  ennreal k * n" using **
      by(transfer)
    also have "exec_gpv callee gpv' s = exec_gpv callee gpv s"
      unfolding gpv'_def using WT_gpv I by(rule exec_gpv_restrict_gpv_invariant)
    finally have ?thesis . }
  from this[cancel_type_definition] I show ?thesis by blast
qed

(* Variant of interaction_bounded_by_exec_gpv_bad_count for a natural-number
   bound (interaction_bounded_by' with n :: nat): the conclusion relates the
   probability of bad to the real number k * n instead of an ennreal.
   The proof converts the goal to ennreal and applies the enat version. *)
lemma interaction_bounded_by'_exec_gpv_bad_count:
  fixes count and bad and n :: nat
  assumes bound: "interaction_bounded_by' consider gpv n"
  and I: "I s"
  and good: "¬ bad s"
  and count: "s x y s'.  (y, s')  set_spmf (callee s x); I s; consider x; x  outs_ℐ    count s'  Suc (count s)"
  and ignore: "s x y s'.  (y, s')  set_spmf (callee s x); I s; ¬ consider x; x  outs_ℐ    count s'  count s"
  and bad: "s' x.  I s'; ¬ bad s'; count s' < n + count s; consider x; x  outs_ℐ    spmf (map_spmf (bad  snd) (callee s' x)) True  k"
  and "consider": "s x y s'.  (y, s')  set_spmf (callee s x); I s; ¬ bad s; bad s'; x  outs_ℐ    consider x"
  and k_nonneg: "k  0"
  and WT_gpv: " ⊢g gpv "
  shows "spmf (map_spmf (bad  snd) (exec_gpv callee gpv s)) True  k * n"
apply(subst ennreal_le_iff[symmetric], simp_all add: k_nonneg ennreal_mult ennreal_real_conv_ennreal_of_enat del: ennreal_of_enat_enat ennreal_le_iff)
apply(rule interaction_bounded_by_exec_gpv_bad_count[OF bound I _ count ignore bad "consider" k_nonneg WT_gpv, OF good])
apply simp_all
done

(* Counter-free special case: if every call counts (interaction_any_bounded_by)
   and each call from a non-bad state satisfying I sets bad with probability
   bounded by k, then the probability of bad after the run is related to k * n.
   NOTE(review): comparison symbols are not legible here; presumably "<=".
   Obtained from interaction_bounded_by_exec_gpv_bad_count with the constant
   counter "%_. 0". *)
lemma interaction_bounded_by_exec_gpv_bad:
  assumes "interaction_any_bounded_by gpv n"
  and "I s" "¬ bad s"
  and bad: "s x.  I s; ¬ bad s; x  outs_ℐ    spmf (map_spmf (bad  snd) (callee s x)) True  k"
  and k_nonneg: "0  k"
  and WT_gpv: " ⊢g gpv "
  shows "spmf (map_spmf (bad  snd) (exec_gpv callee gpv s)) True  k * n"
using interaction_bounded_by_exec_gpv_bad_count[where bad=bad, OF assms(1) assms(2-3), where ?count = "λ_. 0", OF _ _ bad _ k_nonneg] k_nonneg WT_gpv
by(simp add: ennreal_real_conv_ennreal_of_enat[symmetric] ennreal_mult[symmetric] del: ennreal_of_enat_enat)

end

end

Theory Computational_Model

(* Title: Computational_Model.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Oracle combinators›

theory Computational_Model imports 
  Generative_Probabilistic_Value
begin

(* Basic type abbreviations of the computational model.
   security is a synonym for nat; advantage is built from security and real.
   NOTE(review): the function arrows and the state type variable of oracle' and
   oracle are not visible in this extract.  From the visible parts, oracle'
   returns an spmf over pairs whose first component is 'ret, and oracle pairs
   an oracle' with one further component, indexed by security. *)
type_synonym security = nat
type_synonym advantage = "security  real"

type_synonym (, 'call, 'ret) oracle' = "  'call  ('ret × ) spmf"
type_synonym (, 'call, 'ret) "oracle" = "security  (, 'call, 'ret) oracle' × "

(* Print translation registered for the fun type syntax.  tr' matches a type of
   the shape  nat -> ((s1 -> call -> (ret × s2) option pmf) × s3)  and, when
   s1 = s2 and s1 = s3, prints it with the oracle type syntax applied to s1,
   call and ret; in every other case it raises Match.
   The trailing typ command displays the oracle type.
   NOTE(review): the cartouche delimiters that separate the marginal comment
   from the ML body appear to be missing in this extract, as does the state
   type variable -- confirm against the original theory. *)
print_translation ― ‹pretty printing for @{typ "(, 'call, 'ret) oracle"} let
    fun tr' [Const (@{type_syntax nat}, _), 
      Const (@{type_syntax prod}, _) $ 
        (Const (@{type_syntax fun}, _) $ s1 $ 
          (Const (@{type_syntax fun}, _) $ call $
            (Const (@{type_syntax pmf}, _) $
              (Const (@{type_syntax option}, _) $
                (Const (@{type_syntax prod}, _) $ ret $ s2))))) $
        s3] =
      if s1 = s2 andalso s1 = s3 then Syntax.const @{type_syntax oracle} $ s1 $ call $ ret
      else raise Match;
  in [(@{type_syntax "fun"}, K tr')]
  end
typ "(, 'call, 'ret) oracle"

subsection ‹Shared state›

context includes ℐ.lifting lifting_syntax begin

(* Sum of two interfaces, defined by lifting.  On an output Inl out' the
   responses are the Inl-image of resp1 out'; on Inr out' they are the
   Inr-image of resp2 out'. *)
lift_definition plus_ℐ :: "('out, 'ret) ('out', 'ret') ('out + 'out', 'ret + 'ret') ℐ" (infix "" 500)
is "λresp1 resp2. λout. case out of Inl out'  Inl ` resp1 out' | Inr out'  Inr ` resp2 out'" .

(* Selector equations for plus_ℐ: the outputs are the disjoint sum (<+>) of the
   components' outputs, and the responses to Inl x / Inr y are the Inl / Inr
   image of the responses of the left / right component. *)
lemma plus_ℐ_sel [simp]:
  shows outs_plus_ℐ: "outs_ℐ (plus_ℐ ℐl ℐr) = outs_ℐ ℐl <+> outs_ℐ ℐr"
  and responses_plus_ℐ_Inl: "responses_ℐ (plus_ℐ ℐl ℐr) (Inl x) = Inl ` responses_ℐ ℐl x"
  and responses_plus_ℐ_Inr: "responses_ℐ (plus_ℐ ℐl ℐr) (Inr y) = Inr ` responses_ℐ ℐr y"
by(transfer; auto split: sum.split_asm; fail)+

(* The preimage of a disjoint sum A <+> B under Inl is A, and under Inr is B. *)
lemma vimage_Inl_Plus [simp]: "Inl -` (A <+> B) = A" 
  and vimage_Inr_Plus [simp]: "Inr -` (A <+> B) = B"
by auto

(* The preimage under one sum injection of an image under the other injection
   is empty. *)
lemma vimage_Inl_image_Inr: "Inl -` Inr ` A = {}"
  and vimage_Inr_image_Inl: "Inr -` Inl ` A = {}"
by auto

(* Parametricity of plus_ℐ: interfaces related by rel_ℐ C R and rel_ℐ C' R' have
   sums related by rel_ℐ with rel_sum on both the output and the response
   relation.  Declared as a transfer rule. *)
lemma plus_ℐ_parametric [transfer_rule]:
  "(rel_ℐ C R ===> rel_ℐ C' R' ===> rel_ℐ (rel_sum C C') (rel_sum R R')) plus_ℐ plus_ℐ"
apply(rule rel_funI rel_ℐI)+
subgoal premises [transfer_rule] by(simp; rule conjI; transfer_prover)
apply(erule rel_sum.cases; clarsimp simp add: inj_vimage_image_eq vimage_Inl_image_Inr empty_transfer vimage_Inr_image_Inl)
subgoal premises [transfer_rule] by transfer_prover
subgoal premises [transfer_rule] by transfer_prover
done

(* Lifting-setup housekeeping for ℐ.lifting after the definitions above.
   NOTE(review): presumably stores the transfer rules added in this context and
   then deactivates the setup again -- see the Lifting documentation. *)
lifting_update ℐ.lifting
lifting_forget ℐ.lifting

(* Simp rule relating ℐ_trivial of a combination of two interfaces to ℐ_trivial
   of the two components; proved by unfolding ℐ_trivial_def.
   NOTE(review): the interface names, the infix operator and the logical
   connectives are not visible in this extract; the lemma name indicates that
   the combination is plus_ℐ. *)
lemma ℐ_trivial_plus_ℐ [simp]: "ℐ_trivial (1  2)  ℐ_trivial 1  ℐ_trivial 2"
by(auto simp add: ℐ_trivial_def)

end

(* map_ℐ with map_sum in both argument positions distributes over the
   combination of two interfaces.  Proved with ℐ_eqI; the first goal is
   handled by case distinction on the sum value x.
   NOTE(review): the infix between ℐ1 and ℐ2 is not visible in this extract;
   the lemma name indicates the plus_ℐ notation. *)
lemma map_ℐ_plus_ℐ [simp]: 
  "map_ℐ (map_sum f1 f2) (map_sum g1 g2) (ℐ1  ℐ2) = map_ℐ f1 g1 ℐ1  map_ℐ f2 g2 ℐ2"
proof(rule ℐ_eqI[OF Set.set_eqI], goal_cases)
  case (1 x)
  then show ?case by(cases x) auto
qed (auto simp add: image_image)

(* Simp rule about the interface order (le_ℐ_def) on two combined interfaces,
   expressed through the pairs ℐ1, ℐ1' and ℐ2, ℐ2'.
   NOTE(review): the order symbols, infix operators and connectives are not
   visible in this extract; the lemma name suggests an equivalence between the
   order on sums and the componentwise order. *)
lemma le_plus_ℐ_iff [simp]:
  "ℐ1  ℐ2  ℐ1'  ℐ2'  ℐ1  ℐ1'  ℐ2  ℐ2'"
  by(auto 4 4 simp add: le_ℐ_def dest: bspec[where x="Inl _"] bspec[where x="Inr _"])

(* Relates ℐ_full to plus_ℐ ℐ1 ℐ2, given the corresponding relation between
   ℐ_full and each of ℐ1 and ℐ2.
   NOTE(review): the relation symbol is not visible in this extract; the proof
   unfolds le_ℐ_def, so it is presumably the interface order. *)
lemma ℐ_full_le_plus_ℐ: "ℐ_full  plus_ℐ ℐ1 ℐ2" if "ℐ_full  ℐ1" "ℐ_full  ℐ2"
  using that by(auto simp add: le_ℐ_def top_unique)

(* Monotonicity of plus_ℐ in both arguments.
   NOTE(review): the relation symbol is not visible in this extract; the proof
   unfolds le_ℐ_def, so it is presumably the interface order. *)
lemma plus_ℐ_mono: "plus_ℐ ℐ1 ℐ2  plus_ℐ ℐ1' ℐ2'" if "ℐ1  ℐ1'" "ℐ2  ℐ2'" 
  using that by(fastforce simp add: le_ℐ_def)

context
  fixes left :: "('s, 'a, 'b) oracle'"
  and right :: "('s,'c, 'd) oracle'"
  and s :: "'s"
begin

(* Combination of two oracles that share the state s.  A query Inl a is
   answered by left s a and a query Inr b by right s b; the answer (first
   component of the result pair) is tagged with the same injection via apfst. *)
primrec plus_oracle :: "'a + 'c  (('b + 'd) × 's) spmf"
where
  "plus_oracle (Inl a) = map_spmf (apfst Inl) (left s a)"
| "plus_oracle (Inr b) = map_spmf (apfst Inr) (right s b)"

(* Introduction rule: plus_oracle x is lossless if the component oracle that
   matches the shape of x (left for Inl a, right for Inr b) is lossless on the
   corresponding query.  Proved by case distinction on x. *)
lemma lossless_plus_oracleI [intro, simp]:
  " a. x = Inl a  lossless_spmf (left s a); 
     b. x = Inr b  lossless_spmf (right s b) 
   lossless_spmf (plus_oracle x)"
by(cases x) simp_all

(* Split rule for goals mentioning plus_oracle lr: a property P of
   plus_oracle lr is expressed by one case for lr = Inl x and one for
   lr = Inr y, each with the unfolded defining equation. *)
lemma plus_oracle_split:
  "P (plus_oracle lr) 
  (x. lr = Inl x  P (map_spmf (apfst Inl) (left s x))) 
  (y. lr = Inr y  P (map_spmf (apfst Inr) (right s y)))"
by(cases lr) auto

(* Assumption form of plus_oracle_split: the same case analysis on lr, written
   with negations so that it can be used to split premises. *)
lemma plus_oracle_split_asm:
  "P (plus_oracle lr) 
  ¬ ((x. lr = Inl x  ¬ P (map_spmf (apfst Inl) (left s x))) 
     (y. lr = Inr y  ¬ P (map_spmf (apfst Inr) (right s y))))"
by(cases lr) auto

end

(* Infix notation for plus_oracle at priority 500.
   NOTE(review): part of the infix symbol appears to be missing in this extract. *)
notation plus_oracle (infix "O" 500)

context
  fixes left :: "('s, 'a, 'b) oracle'"
  and right :: "('s,'c, 'd) oracle'"
begin

(* Introduction rule: if left s and right s are well-typed callees w.r.t. ℐl and
   ℐr, the combined oracle at state s is well-typed w.r.t. the combined
   interface.  Proved via WT_calleeI / WT_calleeD. *)
lemma WT_plus_oracleI [intro!]:
  " ℐl ⊢c left s ; ℐr ⊢c right s    ℐl  ℐr ⊢c (left O right) s "
by(rule WT_calleeI)(auto elim!: WT_calleeD simp add: inj_image_mem_iff)

(* Destruction rule, left component: well-typedness of the combined oracle
   implies well-typedness of left s w.r.t. ℐl.  The proof embeds a call of
   left into the combined oracle with Inl and projects the resulting response
   membership back using injectivity of Inl. *)
lemma WT_plus_oracleD1:
  assumes "ℐl  ℐr ⊢c (left O right) s  " (is "?ℐ ⊢c ?callee s ")
  shows "ℐl ⊢c left s "
proof(rule WT_calleeI)
  fix call ret s'
  assume "call  outs_ℐ ℐl" "(ret, s')  set_spmf (left s call)"
  hence "(Inl ret, s')  set_spmf (?callee s (Inl call))" "Inl call  outs_ℐ (ℐl  ℐr)"
    by(auto intro: rev_image_eqI)
  hence "Inl ret  responses_ℐ ?ℐ (Inl call)" by(rule WT_calleeD[OF assms])
  then show "ret  responses_ℐ ℐl call" by(simp add: inj_image_mem_iff)
qed

(* Destruction rule, right component: well-typedness of the combined oracle
   implies well-typedness of right s w.r.t. ℐr.  Mirror image of
   WT_plus_oracleD1, using Inr instead of Inl. *)
lemma WT_plus_oracleD2:
  assumes "ℐl  ℐr ⊢c (left O right) s  " (is "?ℐ ⊢c ?callee s ")
  shows "ℐr ⊢c right s "
proof(rule WT_calleeI)
  fix call ret s'
  assume "call  outs_ℐ ℐr" "(ret, s')  set_spmf (right s call)"
  hence "(Inr ret, s')  set_spmf (?callee s (Inr call))" "Inr call  outs_ℐ (ℐl  ℐr)"
    by(auto intro: rev_image_eqI)
  hence "Inr ret  responses_ℐ ?ℐ (Inr call)" by(rule WT_calleeD[OF assms])
  then show "ret  responses_ℐ ℐr call" by(simp add: inj_image_mem_iff)
qed

(* Simp rule combining WT_plus_oracleD1 and WT_plus_oracleD2 (used as
   destruction rules in the proof) with the intro! rule WT_plus_oracleI:
   well-typedness of the combined oracle is characterised by well-typedness
   of both components. *)
lemma WT_plus_oracle_iff [simp]: "ℐl  ℐr ⊢c (left O right) s   ℐl ⊢c left s   ℐr ⊢c right s "
by(blast dest: WT_plus_oracleD1 WT_plus_oracleD2)

(* I is a callee invariant of the combined oracle on the combined interface
   exactly when it is a callee invariant of left on ℐl and of right on ℐr.
   Direction lhs to rhs: interpret the locale for the combined oracle and embed
   calls of left / right with Inl / Inr.
   Direction rhs to lhs: interpret the locale for both components and analyse
   the query x by cases, using projl / projr to refer to the payload. *)
lemma callee_invariant_on_plus_oracle [simp]:
  "callee_invariant_on (left O right) I (ℐl  ℐr) 
   callee_invariant_on left I ℐl  callee_invariant_on right I ℐr"
   (is "?lhs  ?rhs")
proof(intro iffI conjI)
  assume ?lhs
  then interpret plus: callee_invariant_on "left O right" I "ℐl  ℐr" .
  show "callee_invariant_on left I ℐl"
  proof
    fix s x y s'
    assume "(y, s')  set_spmf (left s x)" and "I s" and "x  outs_ℐ ℐl"
    then have "(Inl y, s')  set_spmf ((left O right) s (Inl x))"
      by(auto intro: rev_image_eqI)
    then show "I s'" using I s by(rule plus.callee_invariant)(simp add: x  outs_ℐ ℐl)
  next
    show "ℐl ⊢c left s " if "I s" for s using plus.WT_callee[OF that] by simp
  qed
  show "callee_invariant_on right I ℐr"
  proof
    fix s x y s'
    assume "(y, s')  set_spmf (right s x)" and "I s" and "x  outs_ℐ ℐr"
    then have "(Inr y, s')  set_spmf ((left O right) s (Inr x))"
      by(auto intro: rev_image_eqI)
    then show "I s'" using I s by(rule plus.callee_invariant)(simp add: x  outs_ℐ ℐr)
  next
    show "ℐr ⊢c right s " if "I s" for s using plus.WT_callee[OF that] by simp
  qed
next
  assume ?rhs
  interpret left: callee_invariant_on left I ℐl using ?rhs by simp
  interpret right: callee_invariant_on right I ℐr using ?rhs by simp
  show ?lhs
  proof
    fix s x y s'
    assume "(y, s')  set_spmf ((left O right) s x)" and "I s" and "x  outs_ℐ (ℐl  ℐr)"
    then have "(projl y, s')  set_spmf (left s (projl x))  projl x  outs_ℐ ℐl 
      (projr y, s')  set_spmf (right s (projr x))  projr x  outs_ℐ ℐr"
      by (cases x)  auto
    then show "I s'" using I s 
      by (auto dest: left.callee_invariant right.callee_invariant)
  next
    show "ℐl  ℐr ⊢c (left O right) s " if "I s" for s 
      using left.WT_callee[OF that] right.WT_callee[OF that] by simp
  qed
qed

(* Version of callee_invariant_on_plus_oracle for callee_invariant, i.e. without
   an explicit interface.  The proof first rewrites the left-hand side to
   callee_invariant_on over the combination of ℐ_full with itself (via
   callee_invariant_on_cong) and then applies callee_invariant_on_plus_oracle. *)
lemma callee_invariant_plus_oracle [simp]:
  "callee_invariant (left O right) I 
   callee_invariant left I  callee_invariant right I"
  (is "?lhs   ?rhs")
proof -
  have "?lhs  callee_invariant_on (left O right) I (ℐ_full  ℐ_full)"
    by(rule callee_invariant_on_cong)(auto split: plus_oracle_split_asm)
  also have "  ?rhs" by(rule callee_invariant_on_plus_oracle)
  finally show ?thesis .
qed

(* Parametricity of plus_oracle: related component oracles, related states (S)
   and rel_sum-related queries give rel_spmf-related results, with rel_sum on
   the answers.  Declared as a transfer rule; proved by unfolding the
   definition and running transfer_prover. *)
lemma plus_oracle_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> A ===> rel_spmf (rel_prod B S))
   ===> (S ===> C ===> rel_spmf (rel_prod D S))
   ===> S ===> rel_sum A C ===> rel_spmf (rel_prod (rel_sum B D) S))
   plus_oracle plus_oracle"
unfolding plus_oracle_def[abs_def] by transfer_prover

(* Pointwise relational rule for two combined oracles: the results on queries
   q1, q2 related by rel_sum A C are related as soon as the left components are
   related on Inl queries and the right components on Inr queries.  In contrast
   to plus_oracle_parametric, the component oracles need only be related at
   the concrete states s1, s2 and queries.  The proof splits rel_sum into its
   two cases and discharges each with transfer_prover. *)
lemma rel_spmf_plus_oracle:
  " q1' q2'.  q1 = Inl q1'; q2 = Inl q2'   rel_spmf (rel_prod B S) (left1 s1 q1') (left2 s2 q2');
    q1' q2'.  q1 = Inr q1'; q2 = Inr q2'   rel_spmf (rel_prod D S) (right1 s1 q1') (right2 s2 q2');
    S s1 s2; rel_sum A C q1 q2 
   rel_spmf (rel_prod (rel_sum B D) S) ((left1 O right1) s1 q1) ((left2 O right2) s2 q2)"
apply(erule rel_sum.cases; clarsimp)
 apply(erule meta_allE)+
 apply(erule meta_impE, rule refl)+
 subgoal premises [transfer_rule] by transfer_prover
apply(erule meta_allE)+
apply(erule meta_impE, rule refl)+
subgoal premises [transfer_rule] by transfer_prover
done

end

subsection ‹Shared state with aborts›

context
  fixes left :: "('s, 'a, 'b option) oracle'"
  and right :: "('s,'c, 'd option) oracle'"
  and s :: "'s"
begin

(* Like plus_oracle, but for component oracles whose answers are optional: the
   injection Inl / Inr is applied underneath the option via map_option, and
   the state component is left untouched by apfst. *)
primrec plus_oracle_stop :: "'a + 'c  (('b + 'd) option × 's) spmf"
where
  "plus_oracle_stop (Inl a) = map_spmf (apfst (map_option Inl)) (left s a)"
| "plus_oracle_stop (Inr b) = map_spmf (apfst (map_option Inr)) (right s b)"

(* Introduction rule: plus_oracle_stop x is lossless if the component oracle
   selected by the shape of x is lossless on the corresponding query. *)
lemma lossless_plus_oracle_stopI [intro, simp]:
  " a. x = Inl a  lossless_spmf (left s a); 
     b. x = Inr b  lossless_spmf (right s b) 
   lossless_spmf (plus_oracle_stop x)"
by(cases x) simp_all

(* Split rule for plus_oracle_stop lr: one case for lr = Inl x and one for
   lr = Inr y, each with the unfolded defining equation. *)
lemma plus_oracle_stop_split:
  "P (plus_oracle_stop lr) 
  (x. lr = Inl x  P (map_spmf (apfst (map_option Inl)) (left s x))) 
  (y. lr = Inr y  P (map_spmf (apfst (map_option Inr)) (right s y)))"
by(cases lr) auto

(* Assumption form of plus_oracle_stop_split, written with negations so that it
   can be used to split premises. *)
lemma plus_oracle_stop_split_asm:
  "P (plus_oracle_stop lr) 
  ¬ ((x. lr = Inl x  ¬ P (map_spmf (apfst (map_option Inl)) (left s x))) 
     (y. lr = Inr y  ¬ P (map_spmf (apfst (map_option Inr)) (right s y))))"
by(cases lr) auto

end

(* Infix notation for plus_oracle_stop at priority 500.
   NOTE(review): part of the infix symbol appears to be missing in this extract. *)
notation plus_oracle_stop (infix "OS" 500)

subsection ‹Disjoint state›

context
  fixes left :: "('s1, 'a, 'b) oracle'"
  and right :: "('s2, 'c, 'd) oracle'"
begin

(* Parallel composition of two oracles with separate states.  The combined
   state is a pair (s1, s2).  A query Inl a runs left on s1 and keeps s2; a
   query Inr b runs right on s2 and keeps s1.  Answers are tagged with the
   same injection as the query. *)
fun parallel_oracle :: "('s1 × 's2, 'a + 'c, 'b + 'd) oracle'"
where
  "parallel_oracle (s1, s2) (Inl a) = map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 a)"
| "parallel_oracle (s1, s2) (Inr b) = map_spmf (map_prod Inr (Pair s1)) (right s2 b)"

(* Closed-form characterisation of parallel_oracle as a single lambda term with
   a pair pattern and case_sum; it is unfolded in the parametricity proof of
   parallel_oracle. *)
lemma parallel_oracle_def:
  "parallel_oracle = (λ(s1, s2). case_sum (λa. map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 a)) (λb. map_spmf (map_prod Inr (Pair s1)) (right s2 b)))"
by(auto intro!: ext split: sum.split)

(* Simp rule expressing losslessness of parallel_oracle s12 xy through
   losslessness of left on fst s12 (for xy = Inl x) and of right on snd s12
   (for xy = Inr y). *)
lemma lossless_parallel_oracle [simp]:
  "lossless_spmf (parallel_oracle s12 xy) 
   (x. xy = Inl x  lossless_spmf (left (fst s12) x)) 
   (y. xy = Inr y  lossless_spmf (right (snd s12) y))"
by(cases s12; cases xy) simp_all

(* Split rule for parallel_oracle s1s2 lr: decomposes the state into a pair and
   distinguishes lr = Inl x from lr = Inr y, giving the unfolded defining
   equation in each case. *)
lemma parallel_oracle_split:
  "P (parallel_oracle s1s2 lr) 
  (s1 s2 x. s1s2 = (s1, s2)  lr = Inl x  P (map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 x))) 
  (s1 s2 y. s1s2 = (s1, s2)  lr = Inr y  P (map_spmf (map_prod Inr (Pair s1)) (right s2 y)))"
by(cases s1s2; cases lr) auto

(* Assumption form of parallel_oracle_split, written with negations so that it
   can be used to split premises. *)
lemma parallel_oracle_split_asm:
  "P (parallel_oracle s1s2 lr) 
  ¬ ((s1 s2 x. s1s2 = (s1, s2)  lr = Inl x  ¬ P (map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 x))) 
     (s1 s2 y. s1s2 = (s1, s2)  lr = Inr y  ¬ P (map_spmf (map_prod Inr (Pair s1)) (right s2 y))))"
by(cases s1s2; cases lr) auto

(* If left sl and right sr are well-typed callees w.r.t. ℐl and ℐr, then
   parallel_oracle at the paired state (sl, sr) is well-typed w.r.t.
   plus_ℐ ℐl ℐr. *)
lemma WT_parallel_oracle [intro!, simp]:
  " ℐl ⊢c left sl ; ℐr ⊢c right sr    plus_ℐ ℐl ℐr ⊢c parallel_oracle (sl, sr) "
by(rule WT_calleeI)(auto elim!: WT_calleeD simp add: inj_image_mem_iff)

(* Invariants of the two component oracles combine componentwise: if Il is a
   callee invariant of left on ℐl and Ir one of right on ℐr, then
   pred_prod Il Ir is a callee invariant of parallel_oracle on the combined
   interface.  The proof interprets the locale for both components and
   establishes preservation of the invariant and well-typedness separately. *)
lemma callee_invariant_parallel_oracleI [simp, intro]:
  assumes "callee_invariant_on left Il ℐl" "callee_invariant_on right Ir ℐr"
  shows "callee_invariant_on parallel_oracle (pred_prod Il Ir) (ℐl  ℐr)"
proof
  interpret left: callee_invariant_on left Il ℐl by fact
  interpret right: callee_invariant_on right Ir ℐr by fact

  show "pred_prod Il Ir s12'"
    if "(y, s12')  set_spmf (parallel_oracle s12 x)" and "pred_prod Il Ir s12" and "x  outs_ℐ (ℐl  ℐr)"
    for s12 x y s12' using that
    by(cases s12; cases s12; cases x)(auto dest: left.callee_invariant right.callee_invariant)

  show "ℐl  ℐr ⊢c local.parallel_oracle s " if "pred_prod Il Ir s" for s using that
    by(cases s)(simp add: left.WT_callee right.WT_callee)
qed

end

(* Parametricity of parallel_oracle in the state relations S1, S2 and the call
   relations CALL1, CALL2; the answers are related by equality.  Proved by
   unfolding parallel_oracle_def, folding relator_eq, and transfer_prover. *)
lemma parallel_oracle_parametric:
  includes lifting_syntax shows
  "((S1 ===> CALL1 ===> rel_spmf (rel_prod (=) S1)) 
  ===> (S2 ===> CALL2 ===> rel_spmf (rel_prod (=) S2))
  ===> rel_prod S1 S2 ===> rel_sum CALL1 CALL2 ===> rel_spmf (rel_prod (=) (rel_prod S1 S2)))
  parallel_oracle parallel_oracle"
unfolding parallel_oracle_def[abs_def] by (fold relator_eq)transfer_prover

subsection ‹Indexed oracles›

(* Oracle for an indexed family f of oracles.  The state s is a function from
   indices to component states.  A query (i, x) is answered by f i on the
   component state s i, and the resulting state s' replaces the entry of s
   at index i. *)
definition family_oracle :: "('i  ('s, 'a, 'b) oracle')  ('i  's, 'i × 'a, 'b) oracle'"
where "family_oracle f s = (λ(i, x). map_spmf (λ(y, s'). (y, s(i := s'))) (f i (s i) x))"

(* Applied form of family_oracle on a pair (i, x), with the state update written
   as apsnd (fun_upd s i). *)
lemma family_oracle_apply [simp]:
  "family_oracle f s (i, x) = map_spmf (apsnd (fun_upd s i)) (f i (s i) x)"
by(simp add: family_oracle_def apsnd_def map_prod_def)

(* Relates losslessness of family_oracle f s ix to losslessness of the selected
   member f (fst ix) on the state s (fst ix) and the query snd ix.
   NOTE(review): the connective between the two sides is not visible in this
   extract. *)
lemma lossless_family_oracle:
  "lossless_spmf (family_oracle f s ix)  lossless_spmf (f (fst ix) (s (fst ix)) (snd ix))"
by(simp add: family_oracle_def split_beta)

subsection ‹State extension›

(* Extends the state of a callee with an additional first component s'.  The
   callee runs on the second component s only, and s' is copied unchanged into
   the resulting state pair.  A prefix notation with priority 1000 is attached.
   NOTE(review): the notation symbol is not visible in this extract. *)
definition extend_state_oracle :: "('call, 'ret, 's) callee  ('call, 'ret, 's' × 's) callee" ("_" [1000] 1000)
where "extend_state_oracle callee = (λ(s', s) x. map_spmf (λ(y, s). (y, (s', s))) (callee s x))"

(* Applied form of extend_state_oracle on an explicit state pair (s', s). *)
lemma extend_state_oracle_simps [simp]:
  "extend_state_oracle callee (s', s) x = map_spmf (λ(y, s). (y, (s', s))) (callee s x)"
by(simp add: extend_state_oracle_def)

context includes lifting_syntax begin
(* Parametricity of extend_state_oracle: the extra state component is related by
   an arbitrary relation S', combined with S via rel_prod.  Declared as a
   transfer rule. *)
lemma extend_state_oracle_parametric [transfer_rule]:
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> rel_prod S' S ===> C ===> rel_spmf (rel_prod R (rel_prod S' S)))
  extend_state_oracle extend_state_oracle"
unfolding extend_state_oracle_def[abs_def] by transfer_prover

(* Transfer rule relating an oracle itself (λoracle. oracle) to its state
   extension: states are related through rel_prod2 S, i.e. a plain state on the
   left is related to a state pair on the right.  Unlike
   extend_state_oracle_parametric, this rule is not declared as a
   transfer_rule. *)
lemma extend_state_oracle_transfer:
  "((S ===> C ===> rel_spmf (rel_prod R S)) 
  ===> rel_prod2 S ===> C ===> rel_spmf (rel_prod R (rel_prod2 S)))
  (λoracle. oracle) extend_state_oracle"
unfolding extend_state_oracle_def[abs_def]
apply(rule rel_funI)+
apply clarsimp
apply(drule (1) rel_funD)+
apply(auto simp add: spmf_rel_map split_def dest: rel_funD intro: rel_spmf_mono)
done
end

(* A predicate that only inspects the first component s' of the state pair is a
   callee invariant.
   NOTE(review): per the lemma name the oracle in the statement is presumably
   the state-extended oracle written in prefix notation; that notation is not
   visible in this extract -- confirm. *)
lemma callee_invariant_extend_state_oracle_const [simp]:
  "callee_invariant oracle (λ(s', s). I s')"
by unfold_locales auto

(* Variant of callee_invariant_extend_state_oracle_const with the invariant
   written with fst instead of a pair pattern; not declared as a simp rule.
   NOTE(review): the prefix notation on the oracle is presumably missing from
   this extract -- confirm. *)
lemma callee_invariant_extend_state_oracle_const':
  "callee_invariant oracle (λs. I (fst s))"
by unfold_locales auto

(* Turns a callee into one with optional answers by wrapping every answer in
   Some; the state component is left unchanged by apfst. *)
definition lift_stop_oracle :: "('call, 'ret, 's) callee  ('call, 'ret option, 's) callee"
where "lift_stop_oracle oracle s x = map_spmf (apfst Some) (oracle s x)"

(* The defining equation of lift_stop_oracle, registered as a simp rule. *)
lemma lift_stop_oracle_apply [simp]: "lift_stop_oracle  oracle s x = map_spmf (apfst Some) (oracle s x)"
  by(fact lift_stop_oracle_def)
  
context includes lifting_syntax begin

(* Transfer rule relating an oracle itself (λx. x) to lift_stop_oracle applied
   to a related oracle: answers are related through pcr_Some R. *)
lemma lift_stop_oracle_transfer:
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> (S ===> C ===> rel_spmf (rel_prod (pcr_Some R) S)))
   (λx. x) lift_stop_oracle"
unfolding lift_stop_oracle_def
apply(rule rel_funI)+
apply(drule (1) rel_funD)+
apply(simp add: spmf_rel_map apfst_def prod.rel_map)
done

end

(* Mirror image of extend_state_oracle: the additional state component s' is
   the second component of the pair.  The callee runs on the first component
   and s' is copied unchanged.  A notation with priority 1000 is attached.
   NOTE(review): the notation symbol is not visible in this extract. *)
definition extend_state_oracle2 :: "('call, 'ret, 's) callee  ('call, 'ret, 's × 's') callee" ("_" [1000] 1000)
  where "extend_state_oracle2 callee = (λ(s, s') x. map_spmf (λ(y, s). (y, (s, s'))) (callee s x))"

(* Applied form of extend_state_oracle2 on an explicit state pair (s, s'). *)
lemma extend_state_oracle2_simps [simp]:
  "extend_state_oracle2 callee (s, s') x = map_spmf (λ(y, s). (y, (s, s'))) (callee s x)"
  by(simp add: extend_state_oracle2_def)

(* Parametricity of extend_state_oracle2, with the extra (second) state
   component related by S'.  Declared as a transfer rule. *)
lemma extend_state_oracle2_parametric [transfer_rule]: includes lifting_syntax shows
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> rel_prod S S' ===> C ===> rel_spmf (rel_prod R (rel_prod S S')))
  extend_state_oracle2 extend_state_oracle2"
  unfolding extend_state_oracle2_def[abs_def] by transfer_prover

(* A predicate that only inspects the second component s' of the state pair is a
   callee invariant.
   NOTE(review): per the lemma name the oracle in the statement is presumably
   extend_state_oracle2 applied through its notation, which is not visible in
   this extract -- confirm. *)
lemma callee_invariant_extend_state_oracle2_const [simp]:
  "callee_invariant oracle (λ(s, s'). I s')"
  by unfold_locales auto

(* Variant of callee_invariant_extend_state_oracle2_const with the invariant
   written with snd instead of a pair pattern; not declared as a simp rule.
   NOTE(review): the notation on the oracle is presumably missing from this
   extract -- confirm. *)
lemma callee_invariant_extend_state_oracle2_const':
  "callee_invariant oracle (λs. I (snd s))"
  by unfold_locales auto

(* extend_state_oracle2 distributes over plus_oracle.  Proved by extensionality
   and case distinction on the state pair and the query. *)
lemma extend_state_oracle2_plus_oracle: 
  "extend_state_oracle2 (plus_oracle oracle1 oracle2) = plus_oracle (extend_state_oracle2 oracle1) (extend_state_oracle2 oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: apfst_def spmf.map_comp o_def split_def)
qed

(* Expresses parallel_oracle through plus_oracle.
   NOTE(review): the parenthesised arguments on the right-hand side presumably
   carry the state-extension notations of extend_state_oracle2 and
   extend_state_oracle, which are not visible in this extract -- confirm. *)
lemma parallel_oracle_conv_plus_oracle:
  "parallel_oracle oracle1 oracle2 = plus_oracle (oracle1) (oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (auto simp add: spmf.map_comp apfst_def o_def split_def map_prod_def)
qed

(* Mapping the queries of a parallel_oracle with map_sum f g and its answers
   with map_sum h k is the same as mapping each component oracle separately
   (f and h on the first, g and k on the second).  States are mapped with id. *)
lemma map_sum_parallel_oracle: includes lifting_syntax shows
  "(id ---> map_sum f g ---> map_spmf (map_prod (map_sum h k) id)) (parallel_oracle oracle1 oracle2)
  = parallel_oracle ((id ---> f ---> map_spmf (map_prod h id)) oracle1) ((id ---> g ---> map_spmf (map_prod k id)) oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
qed

(* Same statement as map_sum_parallel_oracle, for plus_oracle: mapping queries
   and answers of the combined oracle with map_sum equals mapping the
   components separately. *)
lemma map_sum_plus_oracle: includes lifting_syntax shows
  "(id ---> map_sum f g ---> map_spmf (map_prod (map_sum h k) id)) (plus_oracle oracle1 oracle2)
  = plus_oracle ((id ---> f ---> map_spmf (map_prod h id)) oracle1) ((id ---> g ---> map_spmf (map_prod k id)) oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases q) (simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
qed

(* Re-association of a triple oracle combination: pre-composing queries with
   rsuml and post-composing answers with lsumr turns the right-nested
   combination into the left-nested one.  The Inl case needs a further case
   distinction on the nested sum ql. *)
lemma map_rsuml_plus_oracle: includes lifting_syntax shows
  "(id ---> rsuml ---> (map_spmf (map_prod lsumr id))) (oracle1 O (oracle2 O oracle3)) =
   ((oracle1 O oracle2) O oracle3)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case 
  proof(cases q)
    case (Inl ql)
    then show ?thesis by(cases ql)(simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
  qed (simp add: spmf.map_comp o_def apfst_def prod.map_comp id_def)
qed

(* Converse re-association to map_rsuml_plus_oracle: pre-composing queries with
   lsumr and post-composing answers with rsuml turns the left-nested
   combination into the right-nested one.  The Inr case needs a further case
   distinction on the nested sum qr. *)
lemma map_lsumr_plus_oracle: includes lifting_syntax shows
  "(id ---> lsumr ---> (map_spmf (map_prod rsuml id))) ((oracle1 O oracle2) O oracle3) =
   (oracle1 O (oracle2 O oracle3))"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case 
  proof(cases q)
    case (Inr qr)
    then show ?thesis by(cases qr)(simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
  qed (simp add: spmf.map_comp o_def apfst_def prod.map_comp id_def)
qed

context includes lifting_syntax begin

(* Lifts an oracle transformer F to oracles whose state carries an additional
   first component t.  The given oracle is specialised to t (Pair t) and its
   result is re-associated with lprodr before F is applied; the result of F at
   state s' is re-associated back with rprodl. *)
definition lift_state_oracle
  :: "(('s  'a  (('b × 't) × 's) spmf)  ('s'  'a  (('b × 't) × 's') spmf)) 
   ('t × 's  'a  ('b × 't × 's) spmf)  ('t × 's'  'a  ('b × 't × 's') spmf)" where
  "lift_state_oracle F oracle = 
   (λ(t, s') a. map_spmf rprodl (F ((Pair t ---> id ---> map_spmf lprodr) oracle) s' a))"

(* Applied form of lift_state_oracle on an explicit state pair (t, s'). *)
lemma lift_state_oracle_simps [simp]:
  "lift_state_oracle F oracle (t, s') a = map_spmf rprodl (F ((Pair t ---> id ---> map_spmf lprodr) oracle) s' a)"
  by(simp add: lift_state_oracle_def)

(* Parametricity of lift_state_oracle in the relations S, S', T, A and B.
   Declared as a transfer rule; proved by unfolding the definition together
   with map_fun_def and o_def and running transfer_prover. *)
lemma lift_state_oracle_parametric [transfer_rule]: includes lifting_syntax shows
  "(((S ===> A ===> rel_spmf (rel_prod (rel_prod B T) S)) ===> S' ===> A ===> rel_spmf (rel_prod (rel_prod B T) S'))
  ===> (rel_prod T S ===> A ===> rel_spmf (rel_prod B (rel_prod T S)))
  ===> rel_prod T S' ===> A ===> rel_spmf (rel_prod B (rel_prod T S')))
  lift_state_oracle lift_state_oracle"
  unfolding lift_state_oracle_def map_fun_def o_def by transfer_prover

(* Lifting a transformer F over a state-extended oracle equals state-extending
   the result of G, provided G and F are related by the stated parametricity
   property for every relation B on answers.  In the proof the assumption is
   instantiated with the relation λx (y, z). x = y and z = t, which ties the
   extra answer component to the fixed extra state t. *)
lemma lift_state_oracle_extend_state_oracle:
  includes lifting_syntax
  assumes "B. Transfer.Rel (((=) ===> (=) ===> rel_spmf (rel_prod B (=))) ===> (=) ===> (=) ===> rel_spmf (rel_prod B (=))) G F"
    (* TODO: implement simproc to discharge parametricity assumptions like this one *)
  shows "lift_state_oracle F (extend_state_oracle oracle) = extend_state_oracle (G oracle)"
  unfolding lift_state_oracle_def extend_state_oracle_def
  apply(clarsimp simp add: fun_eq_iff map_fun_def o_def spmf.map_comp split_def rprodl_def)
  subgoal for t s a
    apply(rule sym)
    apply(fold spmf_rel_eq)
    apply(simp add: spmf_rel_map)
    apply(rule rel_spmf_mono)
     apply(rule assms[unfolded Rel_def, where B="λx (y, z). x = y  z = t", THEN rel_funD, THEN rel_funD, THEN rel_funD])
       apply(auto simp add: rel_fun_def spmf_rel_map intro!: rel_spmf_reflI)
    done
  done

(* Two successive liftings with lift_state_oracle fuse into a single lifting of
   the combined transformer.
   NOTE(review): the composition operator between F and G is not visible in
   this extract. *)
lemma lift_state_oracle_compose: 
  "lift_state_oracle F (lift_state_oracle G oracle) = lift_state_oracle (F  G) oracle"
  by(simp add: lift_state_oracle_def map_fun_def o_def split_def spmf.map_comp)

(* Lifting the identity transformer gives the identity. *)
lemma lift_state_oracle_id [simp]: "lift_state_oracle id = id"
  by(simp add: fun_eq_iff spmf.map_comp o_def)

(* Extending the state twice and re-associating the nested state pairs (rprodl
   on the input state, lprodr on the output state) equals a single state
   extension. *)
lemma rprodl_extend_state_oracle: includes lifting_syntax shows
  "(rprodl ---> id ---> map_spmf (map_prod id lprodr)) (extend_state_oracle (extend_state_oracle oracle)) = 
  extend_state_oracle oracle"
  by(simp add: fun_eq_iff spmf.map_comp o_def split_def)

end

section ‹Combining GPVs›

subsection ‹Shared state without interrupts›

context
  fixes left :: "'s  'x1  ('y1 × 's, 'call, 'ret) gpv"
  and right :: "'s  'x2  ('y2 × 's, 'call, 'ret) gpv"
begin

(* GPV-level analogue of plus_oracle: an input Inl x is handled by left s x and
   an input Inr x by right s x; the first component of the result is tagged
   with the same injection via map_gpv (apfst ...), and the second map_gpv
   argument is id. *)
primrec plus_intercept :: "'s  'x1 + 'x2  (('y1 + 'y2) × 's, 'call, 'ret) gpv"
where
  "plus_intercept s (Inl x) = map_gpv (apfst Inl) id (left s x)"
| "plus_intercept s (Inr x) = map_gpv (apfst Inr) id (right s x)"

end

(* Parametricity of plus_intercept w.r.t. rel_gpv, with rel_sum on inputs and
   results.  Declared as a transfer rule. *)
lemma plus_intercept_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> X1 ===> rel_gpv (rel_prod Y1 S) C)
  ===> (S ===> X2 ===> rel_gpv (rel_prod Y2 S) C)
  ===> S ===> rel_sum X1 X2 ===> rel_gpv (rel_prod (rel_sum Y1 Y2) S) C)
  plus_intercept plus_intercept"
unfolding plus_intercept_def[abs_def] by transfer_prover

(* Interaction bound for plus_intercept: the bound is n x' for an input Inl x'
   and m y for an input Inr y, given the corresponding bounds for left and
   right.  Registered with the interaction_bound attribute. *)
lemma interaction_bounded_by_plus_intercept [interaction_bound]:
  fixes left right
  shows " x'. x = Inl x'  interaction_bounded_by P (left s x') (n x');
    y. x = Inr y  interaction_bounded_by P (right s y) (m y) 
   interaction_bounded_by P (plus_intercept left right s x) (case x of Inl x  n x | Inr y  m y)"
by(simp split!: sum.split add: interaction_bounded_by_map_gpv_id)

subsection ‹Shared state with interrupts›

context 
  fixes left :: "'s  'x1  ('y1 option × 's, 'call, 'ret) gpv"
  and right :: "'s  'x2  ('y2 option × 's, 'call, 'ret) gpv"
begin

(* Like plus_intercept, but for components with optional results: the injection
   Inl / Inr is applied underneath the option via map_option. *)
primrec plus_intercept_stop :: "'s  'x1 + 'x2  (('y1 + 'y2) option × 's, 'call, 'ret) gpv"
where
  "plus_intercept_stop s (Inl x) = map_gpv (apfst (map_option Inl)) id (left s x)"
| "plus_intercept_stop s (Inr x) = map_gpv (apfst (map_option Inr)) id (right s x)"

end

(* Parametricity of plus_intercept_stop, with rel_option around the result
   relations.  Declared as a transfer rule. *)
lemma plus_intercept_stop_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> X1 ===> rel_gpv (rel_prod (rel_option Y1) S) C)
  ===> (S ===> X2 ===> rel_gpv (rel_prod (rel_option Y2) S) C)
  ===> S ===> rel_sum X1 X2 ===> rel_gpv (rel_prod (rel_option (rel_sum Y1 Y2)) S) C)
  plus_intercept_stop plus_intercept_stop"
unfolding plus_intercept_stop_def by transfer_prover

subsection ‹One-sided shifts›

(* Embeds a gpv into one with sum-typed outputs and inputs, on the left side.
   Results are unchanged (id), outputs are tagged with Inl, and the
   continuation recurses on an input Inl input' and is Fail on any other
   input.  The (transfer) option generates the parametricity theorem
   left_gpv.transfer. *)
primcorec (transfer) left_gpv :: "('a, 'out, 'in) gpv  ('a, 'out + 'out', 'in + 'in') gpv" where
  "the_gpv (left_gpv gpv) = 
   map_spmf (map_generat id Inl (λrpv input. case input of Inl input'  left_gpv (rpv input') | _  Fail)) (the_gpv gpv)"

(* Abbreviation for the continuation used by left_gpv: apply left_gpv to
   rpv input' on an input Inl input', and return Fail on any other input. *)
abbreviation left_rpv :: "('a, 'out, 'in) rpv  ('a, 'out + 'out', 'in + 'in') rpv" where
  "left_rpv rpv  λinput. case input of Inl input'  left_gpv (rpv input') | _  Fail"

(* Mirror image of left_gpv: outputs are tagged with Inr, and the continuation
   recurses on an input Inr input' and is Fail on any other input. *)
primcorec (transfer) right_gpv :: "('a, 'out, 'in) gpv  ('a, 'out' + 'out, 'in' + 'in) gpv" where
  "the_gpv (right_gpv gpv) =
   map_spmf (map_generat id Inr (λrpv input. case input of Inr input'  right_gpv (rpv input') | _  Fail)) (the_gpv gpv)"

(* Abbreviation for the continuation used by right_gpv: apply right_gpv to
   rpv input' on an input Inr input', and return Fail on any other input. *)
abbreviation right_rpv :: "('a, 'out, 'in) rpv  ('a, 'out' + 'out, 'in' + 'in) rpv" where
  "right_rpv rpv  λinput. case input of Inr input'  right_gpv (rpv input') | _  Fail"

context 
  includes lifting_syntax
  notes [transfer_rule] = corec_gpv_parametric' Fail_parametric' the_gpv_parametric'
begin

(* left_gpv_parametric names the generated theorem left_gpv.transfer.
   left_gpv_parametric' states parametricity w.r.t. rel_gpv'', which also takes
   a relation on inputs (R, combined with R' by rel_sum); it is proved from the
   definition using the primed transfer rules declared in this context. *)
lemmas left_gpv_parametric = left_gpv.transfer

lemma left_gpv_parametric':
  "(rel_gpv'' A C R ===> rel_gpv'' A (rel_sum C C') (rel_sum R R')) left_gpv left_gpv"
  unfolding left_gpv_def by transfer_prover

(* right_gpv_parametric names the generated theorem right_gpv.transfer.
   right_gpv_parametric' is the rel_gpv'' version; the argument is related via
   C' and R', the right components of the rel_sum relations in the result. *)
lemmas right_gpv_parametric = right_gpv.transfer

lemma right_gpv_parametric':
  "(rel_gpv'' A C' R' ===> rel_gpv'' A (rel_sum C C') (rel_sum R R')) right_gpv right_gpv"
  unfolding right_gpv_def by transfer_prover

end

(* left_gpv leaves a finished computation Done x unchanged. *)
lemma left_gpv_Done [simp]: "left_gpv (Done x) = Done x"
  by(rule gpv.expand) simp

(* right_gpv leaves a finished computation Done x unchanged. *)
lemma right_gpv_Done [simp]: "right_gpv (Done x) = Done x"
  by(rule gpv.expand) simp

(* left_gpv on Pause x rpv: the output becomes Inl x, and the continuation
   recurses on Inl inputs and is Fail on all other inputs. *)
lemma left_gpv_Pause [simp]:
  "left_gpv (Pause x rpv) = Pause (Inl x) (λinput. case input of Inl input'  left_gpv (rpv input') | _  Fail)"
  by(rule gpv.expand) simp

(* right_gpv on Pause x rpv: the output becomes Inr x, and the continuation
   recurses on Inr inputs and is Fail on all other inputs. *)
lemma right_gpv_Pause [simp]:
  "right_gpv (Pause x rpv) = Pause (Inr x) (λinput. case input of Inr input'  right_gpv (rpv input') | _  Fail)"
  by(rule gpv.expand) simp

(* left_gpv commutes with map_gpv: the output map g becomes map_sum g h for an
   arbitrary h on the unused right side.  Derived from the parametricity
   theorem left_gpv.transfer instantiated with function graphs (BNF_Def.Grp). *)
lemma left_gpv_map: "left_gpv (map_gpv f g gpv) = map_gpv f (map_sum g h) (left_gpv gpv)"
  using left_gpv.transfer[of "BNF_Def.Grp UNIV f" "BNF_Def.Grp UNIV g" "BNF_Def.Grp UNIV h"]
  unfolding sum.rel_Grp gpv.rel_Grp
  by(auto simp add: rel_fun_def Grp_def)

(* right_gpv commutes with map_gpv: the output map g becomes map_sum h g for an
   arbitrary h on the unused left side.  Derived from right_gpv.transfer
   instantiated with function graphs (BNF_Def.Grp). *)
lemma right_gpv_map: "right_gpv (map_gpv f g gpv) = map_gpv f (map_sum h g) (right_gpv gpv)"
  using right_gpv.transfer[of "BNF_Def.Grp UNIV f" "BNF_Def.Grp UNIV g" "BNF_Def.Grp UNIV h"]
  unfolding sum.rel_Grp gpv.rel_Grp
  by(auto simp add: rel_fun_def Grp_def)

(* left_gpv does not change the set results'_gpv.  Both inclusions are shown by
   induction on the respective membership; for the first, the induction is over
   a generalised term gpv' equal to left_gpv gpv with gpv arbitrary. *)
lemma results'_gpv_left_gpv [simp]: 
  "results'_gpv (left_gpv gpv :: ('a, 'out + 'out', 'in + 'in') gpv) = results'_gpv gpv" (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv'"left_gpv gpv :: ('a, 'out + 'out', 'in + 'in') gpv" arbitrary: gpv)
      (fastforce simp add: elim!: generat.set_cases intro: results'_gpvI split: sum.splits)+
  show "x  ?lhs" if "x  ?rhs" for x using that
    by(induction)
      (auto 4 3 elim!: generat.set_cases intro: results'_gpv_Pure rev_image_eqI results'_gpv_Cont[where input="Inl _"])
qed

(* right_gpv does not change the set results'_gpv.  Mirror image of
   results'_gpv_left_gpv, with Inr inputs in the second inclusion. *)
lemma results'_gpv_right_gpv [simp]: 
  "results'_gpv (right_gpv gpv :: ('a, 'out' + 'out, 'in' + 'in) gpv) = results'_gpv gpv" (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv'"right_gpv gpv :: ('a, 'out' + 'out, 'in' + 'in) gpv" arbitrary: gpv)
      (fastforce simp add: elim!: generat.set_cases intro: results'_gpvI split: sum.splits)+
  show "x  ?lhs" if "x  ?rhs" for x using that
    by(induction)
      (auto 4 3 elim!: generat.set_cases intro: results'_gpv_Pure rev_image_eqI results'_gpv_Cont[where input="Inr _"])
qed

(* left_gpv gpv is related to gpv itself by rel_gpv'' with equality on results
   and the graph of Inl (λl r. l = Inl r) on outputs and inputs.  Proved by
   coinduction. *)
lemma left_gpv_Inl_transfer: "rel_gpv'' (=) (λl r. l = Inl r) (λl r. l = Inl r) (left_gpv gpv) gpv"
  by(coinduction arbitrary: gpv)
    (auto simp add: spmf_rel_map generat.rel_map del: rel_funI intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI)

(* right_gpv gpv is related to gpv itself by rel_gpv'' with equality on results
   and the graph of Inr (λl r. l = Inr r) on outputs and inputs.  Proved by
   coinduction. *)
lemma right_gpv_Inr_transfer: "rel_gpv'' (=) (λl r. l = Inr r) (λl r. l = Inr r) (right_gpv gpv) gpv"
  by(coinduction arbitrary: gpv)
    (auto simp add: spmf_rel_map generat.rel_map del: rel_funI intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI)

(* Running left_gpv gpv against the combined oracle is the same as running gpv
   against oracle1 alone.  The equation is first restated as a relation
   (spmf_rel_eq, prod.rel_eq) and then obtained from exec_gpv_parametric' with
   the Inl graph as call and response relation, using left_gpv_Inl_transfer. *)
lemma exec_gpv_plus_oracle_left: "exec_gpv (plus_oracle oracle1 oracle2) (left_gpv gpv) s = exec_gpv oracle1 gpv s"
  unfolding spmf_rel_eq[symmetric] prod.rel_eq[symmetric]
  by(rule exec_gpv_parametric'[where A="(=)" and S="(=)" and CALL="λl r. l = Inl r" and R="λl r. l = Inl r", THEN rel_funD, THEN rel_funD, THEN rel_funD])
    (auto intro!: rel_funI simp add: spmf_rel_map apfst_def map_prod_def rel_prod_conv intro: rel_spmf_reflI left_gpv_Inl_transfer)

(* Running right_gpv gpv against the combined oracle is the same as running gpv
   against oracle2 alone.  Mirror image of exec_gpv_plus_oracle_left, with the
   Inr graph and right_gpv_Inr_transfer. *)
lemma exec_gpv_plus_oracle_right: "exec_gpv (plus_oracle oracle1 oracle2) (right_gpv gpv) s = exec_gpv oracle2 gpv s"
  unfolding spmf_rel_eq[symmetric] prod.rel_eq[symmetric]
  by(rule exec_gpv_parametric'[where A="(=)" and S="(=)" and CALL="λl r. l = Inr r" and R="λl r. l = Inr r", THEN rel_funD, THEN rel_funD, THEN rel_funD])
    (auto intro!: rel_funI simp add: spmf_rel_map apfst_def map_prod_def rel_prod_conv intro: rel_spmf_reflI right_gpv_Inr_transfer)

(* left_gpv distributes over bind_gpv: it is applied to the first computation
   and combined with the continuation f.  Proved by strong coinduction.
   NOTE(review): the composition operator between left_gpv and f is not
   visible in this extract. *)
lemma left_gpv_bind_gpv: "left_gpv (bind_gpv gpv f) = bind_gpv (left_gpv gpv) (left_gpv  f)"
  by(coinduction arbitrary:gpv f rule: gpv.coinduct_strong)
    (auto 4 4 simp add: bind_map_spmf spmf_rel_map intro!: rel_spmf_reflI rel_spmf_bindI[of "(=)"] generat.rel_refl rel_funI split: sum.splits)

(* inline1 with a callee wrapped in left_gpv equals inline1 of the plain callee
   followed by a map: in the right (map_sum) alternative the output is tagged
   with Inl and the first continuation is wrapped with left_rpv.  Proved by
   parallel fixpoint induction over the two inline1 instances, with the cases
   adm, bottom and step. *)
lemma inline1_left_gpv:
  "inline1 (λs q. left_gpv (callee s q)) gpv s = 
   map_spmf (map_sum id (map_prod Inl (map_prod left_rpv id))) (inline1 callee gpv s)"
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf inline1.mono inline1.mono inline1_def inline1_def, unfolded lub_spmf_empty, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1' inline1'')
  then show ?case
    by(auto simp add: map_spmf_bind_spmf o_def bind_map_spmf intro!: ext bind_spmf_cong split: generat.split)
qed

(* left_gpv commutes with inline: applying left_gpv to the inlined computation
   equals inlining with a callee whose results are wrapped in left_gpv.
   Proved by coinduction up to bind, using inline1_left_gpv and
   left_gpv_bind_gpv. *)
lemma left_gpv_inline: "left_gpv (inline callee gpv s) = inline (λs q. left_gpv (callee s q)) gpv s"
  by(coinduction arbitrary: callee gpv s rule: gpv_coinduct_bind)
    (fastforce simp add: inline_sel spmf_rel_map inline1_left_gpv left_gpv_bind_gpv o_def split_def intro!: rel_spmf_reflI split: sum.split intro!: rel_funI gpv.rel_refl_strong)

(* right_gpv distributes over bind_gpv; mirror image of left_gpv_bind_gpv.
   NOTE(review): the composition operator between right_gpv and f is not
   visible in this extract. *)
lemma right_gpv_bind_gpv: "right_gpv (bind_gpv gpv f) = bind_gpv (right_gpv gpv) (right_gpv  f)"
  by(coinduction arbitrary:gpv f rule: gpv.coinduct_strong)
    (auto 4 4 simp add: bind_map_spmf spmf_rel_map intro!: rel_spmf_reflI rel_spmf_bindI[of "(=)"] generat.rel_refl rel_funI split: sum.splits)

(* Mirror image of inline1_left_gpv: with a callee wrapped in right_gpv, the
   output is tagged with Inr and the first continuation is wrapped with
   right_rpv.  Proved by the same parallel fixpoint induction. *)
lemma inline1_right_gpv:
  "inline1 (λs q. right_gpv (callee s q)) gpv s = 
   map_spmf (map_sum id (map_prod Inr (map_prod right_rpv id))) (inline1 callee gpv s)"
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf inline1.mono inline1.mono inline1_def inline1_def, unfolded lub_spmf_empty, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1' inline1'')
  then show ?case
    by(auto simp add: map_spmf_bind_spmf o_def bind_map_spmf intro!: ext bind_spmf_cong split: generat.split)
qed

(* right_gpv commutes with inline; mirror image of left_gpv_inline, using
   inline1_right_gpv and right_gpv_bind_gpv. *)
lemma right_gpv_inline: "right_gpv (inline callee gpv s) = inline (λs q. right_gpv (callee s q)) gpv s"
  by(coinduction arbitrary: callee gpv s rule: gpv_coinduct_bind)
    (fastforce simp add: inline_sel spmf_rel_map inline1_right_gpv right_gpv_bind_gpv o_def split_def intro!: rel_spmf_reflI split: sum.split intro!: rel_funI gpv.rel_refl_strong)

(* Well-typedness is preserved by left_gpv: a gpv that is well-typed w.r.t. ℐ1
   yields a left_gpv that is well-typed w.r.t. the combination of ℐ1 and ℐ2.
   Proved by coinduction. *)
lemma WT_gpv_left_gpv: "ℐ1 ⊢g gpv   ℐ1  ℐ2 ⊢g left_gpv gpv "
  by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)

(* Well-typedness is preserved by right_gpv: a gpv that is well-typed w.r.t. ℐ2
   yields a right_gpv that is well-typed w.r.t. the combination of ℐ1 and ℐ2.
   Proved by coinduction. *)
lemma WT_gpv_right_gpv: "ℐ2 ⊢g gpv   ℐ1  ℐ2 ⊢g right_gpv gpv "
  by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)

(* The interface-relative result set results_gpv of left_gpv gpv under the
   combined interface equals that of gpv under ℐ1.  Both inclusions are shown
   by induction over results_gpv. *)
lemma results_gpv_left_gpv [simp]: "results_gpv (ℐ1  ℐ2) (left_gpv gpv) = results_gpv ℐ1 gpv"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv'"left_gpv gpv :: ('a, 'b + 'c, 'd + 'e) gpv" arbitrary: gpv rule: results_gpv.induct)
      (fastforce intro: results_gpv.intros)+
  show "x  ?lhs" if "x  ?rhs" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed

(* The interface-relative result set results_gpv of right_gpv gpv under the
   combined interface equals that of gpv under ℐ2.  Mirror image of
   results_gpv_left_gpv. *)
lemma results_gpv_right_gpv [simp]: "results_gpv (ℐ1  ℐ2) (right_gpv gpv) = results_gpv ℐ2 gpv"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv'"right_gpv gpv :: ('a, 'b + 'c, 'd + 'e) gpv" arbitrary: gpv rule: results_gpv.induct)
      (fastforce intro: results_gpv.intros)+
  show "x  ?lhs" if "x  ?rhs" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed

(* left_gpv maps the failing computation Fail to Fail. *)
lemma left_gpv_Fail [simp]: "left_gpv Fail = Fail"
  by(rule gpv.expand) auto

(* right_gpv maps the failing computation Fail to Fail. *)
lemma right_gpv_Fail [simp]: "right_gpv Fail = Fail"
  by(rule gpv.expand) auto

(* Re-association of nested sums on a gpv: mapping outputs with rsuml and inputs
   with lsumr turns a doubly left-embedded gpv into a singly left-embedded one.
   Proved by coinduction; Fail is supplied as witness for inputs of the wrong
   shape. *)
lemma rsuml_lsumr_left_gpv_left_gpv:"map_gpv' id rsuml lsumr (left_gpv (left_gpv gpv)) = left_gpv gpv"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])

(* Mapping left_gpv (right_gpv gpv) with rsuml and lsumr (via map_gpv', result map id)
   yields right_gpv (left_gpv gpv).  Proved by coinduction on gpv. *)
lemma rsuml_lsumr_left_gpv_right_gpv: "map_gpv' id rsuml lsumr (left_gpv (right_gpv gpv)) = right_gpv (left_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])

(* Mapping right_gpv gpv with rsuml and lsumr (via map_gpv', result map id) yields the
   doubly right-embedded GPV right_gpv (right_gpv gpv).  Proved by coinduction on gpv. *)
lemma rsuml_lsumr_right_gpv: "map_gpv' id rsuml lsumr (right_gpv gpv) = right_gpv (right_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])

(* Exchanges map_gpv' and map_gpv: the result maps f and f' are combined into a single
   map_gpv on the outside, while map_gpv' keeps only g and h (with result map id).
   NOTE(review): the operator between f and f' is missing in this copy; presumably
   function composition. *)
lemma map_gpv'_map_gpv_swap:
  "map_gpv' f g h (map_gpv f' id gpv) = map_gpv (f  f') id (map_gpv' id g h gpv)"
  by(simp add: map_gpv_conv_map_gpv' map_gpv'_comp)

(* Mapping left_gpv gpv with lsumr and rsuml (via map_gpv', result map id) yields the
   doubly left-embedded GPV left_gpv (left_gpv gpv).  Proved by coinduction on gpv. *)
lemma lsumr_rsuml_left_gpv: "map_gpv' id lsumr rsuml (left_gpv gpv) = left_gpv (left_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split intro: exI[where x=Fail])

(* Mapping right_gpv (left_gpv gpv) with lsumr and rsuml (via map_gpv', result map id)
   yields left_gpv (right_gpv gpv).  Proved by coinduction on gpv. *)
lemma lsumr_rsuml_right_gpv_left_gpv:
  "map_gpv' id lsumr rsuml (right_gpv (left_gpv gpv)) = left_gpv (right_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split intro: exI[where x=Fail])

(* Mapping the doubly right-embedded GPV with lsumr and rsuml (via map_gpv', result
   map id) yields the singly right-embedded GPV.  Proved by coinduction on gpv. *)
lemma lsumr_rsuml_right_gpv_right_gpv:
  "map_gpv' id lsumr rsuml (right_gpv (right_gpv gpv)) = right_gpv gpv"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: rsuml.elims intro: exI[where x=Fail])


(* Simp rule characterising the support of extend_state_oracle oracle s y: it relates an
   outcome x to the conditions that the first state component of x is fst s and that the
   remaining components (fst x, snd (snd x)) come from oracle (snd s) y.
   NOTE(review): membership and logical connectives are missing in this copy of the
   statement; presumably an equivalence between membership and a conjunction. *)
lemma in_set_spmf_extend_state_oracle [simp]:
  "x  set_spmf (extend_state_oracle oracle s y) 
   fst (snd x) = fst s  (fst x, snd (snd x))  set_spmf (oracle (snd s) y)"
  by(auto 4 4 simp add: extend_state_oracle_def split_beta intro: rev_image_eqI prod.expand)

(* extend_state_oracle distributes over plus_oracle: extending the combined oracle equals
   combining the two extended oracles.  Shown pointwise (function extensionality) by case
   analysis on the state s and the query q. *)
lemma extend_state_oracle_plus_oracle: 
  "extend_state_oracle (plus_oracle oracle1 oracle2) = plus_oracle (extend_state_oracle oracle1) (extend_state_oracle oracle2)"
proof ((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: apfst_def spmf.map_comp o_def split_def)
qed


(* Turns a callee without state into one that takes a state argument s: every result b
   of the given callee is paired with the unchanged s, via map_gpv (λb. (b, s)) id.
   NOTE(review): type arrows and the operator before "callee" are missing in this copy;
   presumably function composition. *)
definition stateless_callee :: "('a  ('b, 'out, 'in) gpv)  ('s  'a  ('b × 's, 'out, 'in) gpv)" where
  "stateless_callee callee s = map_gpv (λb. (b, s)) id  callee"

(* Parametricity of stateless_callee w.r.t. rel_gpv'': related callees and related states
   give related stateful callees, with results related by rel_prod B S.
   Proved by transfer_prover after unfolding the definition, using map_gpv_parametric'
   as a transfer rule. *)
lemma stateless_callee_parametric': 
  includes lifting_syntax notes [transfer_rule] = map_gpv_parametric' shows
    "((A ===> rel_gpv'' B C R) ===> S ===> A ===> (rel_gpv'' (rel_prod B S) C R))
   stateless_callee stateless_callee"
  unfolding stateless_callee_def by transfer_prover

(* Alternative characterisation of id_oracle: it is the stateless callee that issues its
   query x via Pause and continues with Done.
   The lemma name is misspelled ("oralce"); it is kept unchanged so that existing
   references keep working. *)
lemma id_oralce_alt_def: "id_oracle = stateless_callee (λx. Pause x Done)"
  by(simp add: id_oracle_def fun_eq_iff stateless_callee_def)

(* Correctly spelled alias for the fact above. *)
lemmas id_oracle_alt_def = id_oralce_alt_def

(* Parallel composition of two stateful interceptors.
   The context fixes `left` (state 's1, inputs 'x1) and `right` (state 's2, inputs 'x2);
   each returns a result paired with its updated state. *)
context
  fixes left :: "'s1  'x1  ('y1 × 's1, 'call1, 'ret1) gpv"
    and right :: "'s2  'x2  ('y2 × 's2, 'call2, 'ret2) gpv"
begin

(* parallel_intercept works on the pair state (s1, s2) and dispatches on the sum-typed input:
   - Inl a: run left on s1; tag the result with Inl, pair the new first state with the
     unchanged s2, and embed the computation with left_gpv;
   - Inr b: run right on s2; tag the result with Inr, pair the unchanged s1 with the new
     second state, and embed the computation with right_gpv. *)
fun parallel_intercept :: "'s1 × 's2  'x1 + 'x2  (('y1 + 'y2) × ('s1 × 's2), 'call1 + 'call2, 'ret1 + 'ret2) gpv"
  where
    "parallel_intercept (s1, s2) (Inl a) = left_gpv (map_gpv (map_prod Inl (λs1'. (s1', s2))) id (left s1 a))"
  | "parallel_intercept (s1, s2) (Inr b) = right_gpv (map_gpv (map_prod Inr (Pair s1)) id (right s2 b))"

end

end

Theory GPV_Expectation

(* Title: GPV_Expectation.thy
  Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹Expectation transformer semantics›

theory GPV_Expectation imports
  Computational_Model
begin

(* Introduction rule for a bound of the form "x compared with enn2real y", derived from a
   bound between ennreal x and y plus a side condition for a special value of y.
   NOTE(review): the comparison symbols, the implication and the constant after "y ="
   are missing in this copy; presumably "≤" and the top element -- confirm. *)
lemma le_enn2realI: " ennreal x  y; y =   x  0   x  enn2real y"
by(cases y) simp_all

(* Destruction rule: from enn2real x < y and a side condition on x, conclude
   x < ennreal y.
   NOTE(review): the side condition on x is garbled in this copy; presumably it
   excludes the top element. *)
lemma enn2real_leD: " enn2real x < y; x     x < ennreal y"
by(cases x)(simp_all add: ennreal_lessI)

(* Introduction rule for comparing x * y with y in ennreal, under a hypothesis that
   mentions y > 0 and a comparison of x with 1.
   NOTE(review): the comparison and implication symbols are missing in this copy;
   presumably "y > 0 implies x ≤ 1" yields "x * y ≤ y" -- confirm. *)
lemma ennreal_mult_le_self2I: " y > 0  x  1   x * y  y" for x y :: ennreal
apply(cases x; cases y)
apply(auto simp add: top_unique ennreal_top_mult ennreal_mult[symmetric] intro: ccontr)
using mult_left_le_one_le by force

(* Transfers a comparison between x and enn2real y to one between ennreal x and y.
   NOTE(review): the comparison and implication symbols are missing in this copy;
   presumably "≤" on both sides. *)
lemma ennreal_leI: "x  enn2real y  ennreal x  y"
by(cases y) simp_all

(* enn2real commutes with an indexed infimum: enn2real (INF over A of f) equals the INF
   over A of enn2real (f x), under a condition on A and a bound "f x < ..." for the
   elements of A.
   NOTE(review): the hypotheses are garbled in this copy; presumably A is non-empty and
   every f x is below the top element.
   The proof shows both inequalities (antisym): one via cINF_greatest and enn2real_mono,
   the other via le_enn2realI, INF_greatest and cINF_lower with lower bound 0. *)
lemma enn2real_INF: " A  {}; xA. f x <    enn2real (INF xA. f x) = (INF xA. enn2real (f x))"
apply(rule antisym)
 apply(rule cINF_greatest)
  apply simp
 apply(rule enn2real_mono)
  apply(erule INF_lower)
 apply simp
apply(rule le_enn2realI)
 apply simp_all
apply(rule INF_greatest)
apply(rule ennreal_leI)
apply(rule cINF_lower)
apply(rule bdd_belowI[where m=0])
apply auto
done

(* Multiplication by a fixed right factor y is monotone on ennreal. *)
lemma monotone_times_ennreal1: "monotone (≤) (≤) (λx. x * y :: ennreal)"
by(auto intro!: monotoneI mult_right_mono)

(* Multiplication by a fixed left factor y is monotone on ennreal. *)
lemma monotone_times_ennreal2: "monotone (≤) (≤) (λx. y * x :: ennreal)"
by(auto intro!: monotoneI mult_left_mono)

(* Multiplication on ennreal is monotone in both arguments jointly (componentwise order
   on pairs).  The fact is stored as monotone_times_ennreal; the attributed form
   mono2mono_times_ennreal (via lfp.mono2mono2) is registered as cont_intro and simp rule. *)
lemma mono2mono_times_ennreal[THEN lfp.mono2mono2, cont_intro, simp]:
  shows monotone_times_ennreal: "monotone (rel_prod (≤) (≤)) (≤) (λ(x, y). x * y :: ennreal)"
by(simp add: monotone_times_ennreal1 monotone_times_ennreal2)

(* Multiplication by a fixed left factor x is mcont (w.r.t. Sup and ≤ on both sides)
   on ennreal; the proof uses SUP_mult_left_ennreal. *)
lemma mcont_times_ennreal1: "mcont Sup (≤) Sup (≤) (λy. x * y :: ennreal)"
by(auto intro!: mcontI contI simp add: SUP_mult_left_ennreal[symmetric])

(* Multiplication by a fixed right factor x is mcont on ennreal; reduced to the
   left-factor version by commutativity of multiplication. *)
lemma mcont_times_ennreal2: "mcont Sup (≤) Sup (≤) (λy. y * x :: ennreal)"
by(subst mult.commute)(rule mcont_times_ennreal1)

(* Continuity rule (cont_intro, simp): if f and g are mcont into ennreal (w.r.t. lub/ord
   on the source and Sup/≤ on the target), then so is their pointwise product. *)
lemma mcont2mcont_times_ennreal [cont_intro, simp]:
  " mcont lub ord Sup (≤) (λx. f x);
    mcont lub ord Sup (≤) (λx. g x) 
   mcont lub ord Sup (≤) (λx. f x * g x :: ennreal)"
by(best intro: ccpo.mcont2mcont'[OF complete_lattice_ccpo] mcont_times_ennreal1 mcont_times_ennreal2 ccpo.mcont_const[OF complete_lattice_ccpo])

(* For 0 < c, a constant left factor c can be pulled out of an indexed infimum over I
   of ereal values.  Derived from ereal_Inf_cmult instantiated with the predicate
   describing the image of f on I. *)
lemma ereal_INF_cmult: "0 < c  (INF iI. c * f i) = ereal c * (INF iI. f i)"
using ereal_Inf_cmult[where P="λx. iI. x = f i", of c]
by(rule box_equals)(auto intro!: arg_cong[where f="Inf"] arg_cong2[where f="(*)"])

(* Right-factor variant of ereal_INF_cmult: for 0 < c, a constant right factor c can be
   pulled out of an indexed infimum; obtained by commutativity of multiplication. *)
lemma ereal_INF_multc: "0 < c  (INF iI. f i * c) = (INF iI. f i) * ereal c"
using ereal_INF_cmult[of c f I] by(simp add: mult.commute)

(* A constant left factor c commutes with an indexed infimum in ennreal:
   c * (INF over I of f) = INF over I of (c * f i), under two side conditions, one for
   the case I = {} and one for a special value of c.
   NOTE(review): parts of the side conditions are missing in this copy (the value
   compared with c, quantifiers and comparisons) -- confirm against the original.
   The proof distinguishes four cases: I empty, c the `top` case, c = 0, and the
   `normal` case, which is transferred to ereal and solved with ereal_INF_cmult. *)
lemma INF_mult_left_ennreal: 
  assumes "I = {}  c  0"
  and " c = ; iI. f i > 0   p>0. iI. f i  p"
  shows "c * (INF iI. f i) = (INF iI. c * f i ::ennreal)"
proof -
  consider (empty) "I = {}" | (top) "c = " | (zero) "c = 0" | (normal) "I  {}" "c  " "c  0" by auto
  then show ?thesis
  proof cases
    case empty then show ?thesis by(simp add: ennreal_mult_top assms(1))
  next
    case top
    show ?thesis
    proof(cases "iI. f i > 0")
      case True
      with assms(2) top obtain p where "p > 0" and p: "i. i  I  f i  p" by auto
      then have *: "i. i  I  f i > 0" by(auto intro: less_le_trans)
      note 0 < p also from p have "p  (INF iI. f i)" by(rule INF_greatest)
      finally show ?thesis using top by(auto simp add: ennreal_top_mult dest: *)
    next
      case False
      hence "f i = 0" if "i  I" for i using that by auto
      thus ?thesis using top by(simp add: INF_constant ennreal_mult_top)
    qed
  next
    case zero
    then show ?thesis using assms(1) by(auto simp add: INF_constant)
  next
    case normal
    then show ?thesis including ennreal.lifting
      apply transfer
      subgoal for I c f by(cases c)(simp_all add: top_ereal_def ereal_INF_cmult)
      done
  qed
qed

(* Mapping a function over an spmf does not change the probability of None. *)
lemma pmf_map_spmf_None: "pmf (map_spmf f p) None = pmf p None"
by(simp add: pmf_None_eq_weight_spmf)

(* Integral over try_spmf p q: the integral of f over p plus the integral of f over q
   weighted by the probability that p yields None. *)
lemma nn_integral_try_spmf:
  "nn_integral (measure_spmf (try_spmf p q)) f = nn_integral (measure_spmf p) f + nn_integral (measure_spmf q) f * pmf p None"
by(simp add: nn_integral_measure_spmf spmf_try_spmf distrib_right nn_integral_add ennreal_mult mult.assoc nn_integral_cmult)
  (simp add: mult.commute)

(* An infimum over a family of sets B x indexed by A equals the nested infimum, first
   over x in A and then over z in B x; stated for any complete lattice.
   NOTE(review): the set operator on the left-hand side is missing in this copy;
   presumably the indexed union. *)
lemma INF_UNION: "(INF z  xA. B x. f z) = (INF xA. INF zB x. f z)" for f :: "_  'b::complete_lattice"
by(auto intro!: antisym INF_greatest intro: INF_lower2)


(* nn_integral_spmf p is nn_integral taken w.r.t. measure_spmf p; a dedicated constant
   that is given its own parametricity rule below and is folded in the parametricity
   proof of expectation_gpv. *)
definition nn_integral_spmf :: "'a spmf  ('a  ennreal)  ennreal" where
  "nn_integral_spmf p = nn_integral (measure_spmf p)"

(* Transfer rule: nn_integral_spmf is parametric -- A-related spmfs and integrands that
   agree on A-related arguments give equal integrals.
   The proof obtains a joint spmf pq whose projections are p and q and whose support
   consists of related pairs (rel_spmfE), then compares the integrals pointwise. *)
lemma nn_integral_spmf_parametric [transfer_rule]:
  includes lifting_syntax
  shows "(rel_spmf A ===> (A ===> (=)) ===> (=)) nn_integral_spmf nn_integral_spmf"
  unfolding nn_integral_spmf_def
proof(rule rel_funI)+
  fix p q and f g :: "_  ennreal"
  assume pq: "rel_spmf A p q" and fg: "(A ===> (=)) f g"
  from pq obtain pq where pq [rule_format]: "(x, y)set_spmf pq. A x y"
    and p: "p = map_spmf fst pq" and q: "q = map_spmf snd pq"
    by(cases rule: rel_spmfE) auto
  show "nn_integral (measure_spmf p) f = nn_integral (measure_spmf q) g"
    by(simp add: p q)(auto simp add: nn_integral_measure_spmf spmf_eq_0_set_spmf dest!: pq rel_funD[OF fg] intro: ennreal_mult_left_cong intro!: nn_integral_cong)
qed

(* The weight of an spmf, viewed as an ennreal, is mcont from the spmf order
   (lub_spmf, ord_spmf (=)) to (Sup, ≤).  The plain fact is weight_spmf_mcont; the
   attributed form (via lfp.mcont2mcont) is registered as a cont_intro rule. *)
lemma weight_spmf_mcont2mcont [THEN lfp.mcont2mcont, cont_intro]:
  shows weight_spmf_mcont: "mcont (lub_spmf) (ord_spmf (=)) Sup (≤) (λp. ennreal (weight_spmf p))"
apply(simp add: mcont_def cont_def weight_spmf_def measure_spmf.emeasure_eq_measure[symmetric] emeasure_lub_spmf)
apply(rule call_mono[THEN lfp.mono2mono])
apply(unfold fun_ord_def)
apply(rule monotone_emeasure_spmf[unfolded le_fun_def])
done

(* The integral of a fixed integrand f w.r.t. measure_spmf p is monotone in p (spmf
   order).  The plain fact is monotone_nn_integral_spmf; the attributed form (via
   lfp.mono2mono) is registered as a cont_intro rule. *)
lemma mono2mono_nn_integral_spmf [THEN lfp.mono2mono, cont_intro]:
  shows monotone_nn_integral_spmf: "monotone (ord_spmf (=)) (≤) (λp. integralN (measure_spmf p) f)"
by(rule monotoneI)(auto simp add: nn_integral_measure_spmf intro!: nn_integral_mono mult_right_mono dest: monotone_spmf[THEN monotoneD])

(* Continuity of integration in the spmf argument: for a chain Y of spmfs, integrating f
   over lub_spmf Y equals the SUP over p in Y of the integrals of f over p.
   Proof outline (a chain of equalities):
   1. rewrite the integral as an integral over the counting space ?M on the support
      of lub_spmf Y;
   2. express spmf (lub_spmf Y) x as a SUP over Y and move the factor f x inside;
   3. swap SUP and integral with nn_integral_monotone_convergence_SUP_countable,
      using that the integrands form a chain;
   4. turn each summand back into an integral w.r.t. measure_spmf p. *)
lemma cont_nn_integral_spmf:
  "cont lub_spmf (ord_spmf (=)) Sup (≤) (λp :: 'a spmf. nn_integral (measure_spmf p) f)"
proof
  fix Y :: "'a spmf set"
  assume Y: "Complete_Partial_Order.chain (ord_spmf (=)) Y" "Y  {}"
  let ?M = "count_space (set_spmf (lub_spmf Y))"
  have "nn_integral (measure_spmf (lub_spmf Y)) f = + x. ennreal (spmf (lub_spmf Y) x) * f x ?M"
    by(simp add: nn_integral_measure_spmf')
  also have " = + x. (SUP pY. ennreal (spmf p x) * f x) ?M"
    by(simp add: spmf_lub_spmf Y ennreal_SUP[OF SUP_spmf_neq_top'] SUP_mult_right_ennreal)
  also have " = (SUP pY. + x. ennreal (spmf p x) * f x ?M)"
  proof(rule nn_integral_monotone_convergence_SUP_countable)
    show "Complete_Partial_Order.chain (≤) ((λi x. ennreal (spmf i x) * f x) ` Y)"
      using Y(1) by(rule chain_imageI)(auto simp add: le_fun_def intro!: mult_right_mono dest: monotone_spmf[THEN monotoneD])
  qed(simp_all add: Y(2))
  also have " = (SUP pY. nn_integral (measure_spmf p) f)"
    by(auto simp add: nn_integral_measure_spmf Y nn_integral_count_space_indicator set_lub_spmf spmf_eq_0_set_spmf split: split_indicator intro!: SUP_cong nn_integral_cong)
  finally show "nn_integral (measure_spmf (lub_spmf Y)) f = (SUP pY. nn_integral (measure_spmf p) f)" .
qed

(* Integration of a fixed integrand f is mcont in the spmf argument; combines
   cont_nn_integral_spmf with the simp set via mcontI.  The plain fact is
   mcont_nn_integral_spmf; the attributed form is registered as a cont_intro rule. *)
lemma mcont2mcont_nn_integral_spmf [THEN lfp.mcont2mcont, cont_intro]:
  shows mcont_nn_integral_spmf:
  "mcont lub_spmf (ord_spmf (=)) Sup (≤) (λp :: 'a spmf. nn_integral (measure_spmf p) f)"
by(rule mcontI)(simp_all add: cont_nn_integral_spmf)
 

(* Monotonicity lifts through integration: if f ↦ F f x is monotone for the points x in
   space M, then f ↦ nn_integral M (F f) is monotone. *)
lemma nn_integral_mono2mono:
  assumes "x. x  space M  monotone ord (≤) (λf. F f x)"
  shows "monotone ord (≤) (λf. nn_integral M (F f))"
  by(rule monotoneI nn_integral_mono monotoneD[OF assms])+

(* Monotonicity rule for partial_function (lfp mode): nn_integral preserves
   lfp.mono_body.  Stated without the space-membership premise of
   nn_integral_mono2mono, for the reason given in the marginal comment below. *)
lemma nn_integral_mono_lfp [partial_function_mono]:
  ― ‹@{ML Partial_Function.mono_tac} does not like conditional assumptions (more precisely the case splitter)›
  "(x. lfp.mono_body (λf. F f x))  lfp.mono_body (λf. nn_integral M (F f))"
  by(rule nn_integral_mono2mono)

(* Monotonicity rule for partial_function (lfp mode): an indexed infimum over M
   preserves lfp.mono_body when each component F f x does. *)
lemma INF_mono_lfp [partial_function_mono]:
  "(x. lfp.mono_body (λf. F f x))  lfp.mono_body (λf. INF xM. F f x)"
  by(rule monotoneI)(blast dest: monotoneD intro: INF_mono)

(* Instance of parallel_fixp_induct_uc for a pair of recursive functions where the first
   is used as is (conversions λx. x) and the second is converted with case_prod / curry
   -- presumably a one-argument and a curried two-argument function, as the name
   suggests.  The predicate P is kept as a parameter (for P). *)
lemmas parallel_fixp_induct_1_2 = parallel_fixp_induct_uc[
  of _ _ _ _ "λx. x" _ "λx. x" "case_prod" _ "curry",
  where P="λf g. P f (curry g)",
  unfolded case_prod_curry curry_case_prod curry_K,
  OF _ _ _ _ _ _ refl refl]
  for P

(* Addition of a fixed right summand y is monotone on ennreal. *)
lemma monotone_ennreal_add1: "monotone (≤) (≤) (λx. x + y :: ennreal)"
by(auto intro!: monotoneI)

(* Addition of a fixed left summand x is monotone on ennreal. *)
lemma monotone_ennreal_add2: "monotone (≤) (≤) (λy. x + y :: ennreal)"
by(auto intro!: monotoneI)

(* Addition on ennreal is monotone in both arguments jointly (componentwise order on
   pairs).  The plain fact is monotone_eadd; the attributed form (via lfp.mono2mono2) is
   registered as cont_intro and simp rule. *)
lemma mono2mono_ennreal_add[THEN lfp.mono2mono2, cont_intro, simp]:
  shows monotone_eadd: "monotone (rel_prod (≤) (≤)) (≤) (λ(x, y). x + y :: ennreal)"
by(simp add: monotone_ennreal_add1 monotone_ennreal_add2)

(* Monotonicity rule for partial_function: the pointwise sum of two functionals that are
   monotone w.r.t. fun_ord (≤) is again monotone. *)
lemma ennreal_add_partial_function_mono [partial_function_mono]:
  " monotone (fun_ord (≤)) (≤) f; monotone (fun_ord (≤)) (≤) g 
   monotone (fun_ord (≤)) (≤) (λx. f x + g x :: ennreal)"
by(rule mono2mono_ennreal_add)

(* Definition of expectation_gpv and its basic equations.
   The context fixes: `fail`, the value attributed to failure; an interface of type
   ('out, 'ret) ℐ (its name is not visible in this copy); and `f`, the function
   applied to results. *)
context
  fixes fail :: ennreal
  and  :: "('out, 'ret) ℐ"
  and f :: "'a  ennreal"
  notes [[function_internals]]
begin

(* expectation_gpv is defined as a fixpoint (partial_function, mode lfp_strong).
   Its value on gpv integrates, w.r.t. measure_spmf (the_gpv gpv), over the generative
   steps: a Pure x step contributes f x; an IO out c step contributes the INF, over the
   responses r taken from responses_ℐ for out, of the recursive value on c r.  To this,
   fail weighted by the probability of None is added.
   NOTE(review): the integral sign, case arrows and the interface argument are missing
   in this copy of the equation. *)
partial_function (lfp_strong) expectation_gpv :: "('a, 'out, 'ret) gpv  ennreal" where
  "expectation_gpv gpv = 
  (+ generat. (case generat of Pure x  f x 
              | IO out c  INF rresponses_ℐ  out. expectation_gpv (c r)) measure_spmf (the_gpv gpv))
   + fail * pmf (the_gpv gpv) None"

(* Fixpoint induction for expectation_gpv with named cases: `adm` (admissibility of P),
   `bottom` (P holds for the constant-0 function) and `step` (P is preserved by one
   unfolding of the defining equation; the step may also assume a relation between the
   approximation expectation_gpv' and expectation_gpv itself). *)
lemma expectation_gpv_fixp_induct [case_names adm bottom step]:
  assumes "lfp.admissible P"
    and "P (λ_. 0)"
    and "expectation_gpv'.  gpv. expectation_gpv' gpv  expectation_gpv gpv; P expectation_gpv'  
         P (λgpv. (+ generat. (case generat of Pure x  f x | IO out c  INF rresponses_ℐ  out. expectation_gpv' (c r)) measure_spmf (the_gpv gpv)) + fail * pmf (the_gpv gpv) None)"
  shows "P expectation_gpv"
  by(rule expectation_gpv.fixp_induct)(simp_all add: bot_ennreal_def assms fun_ord_def)
  
(* Simp rule: on an immediate result x the value is f x. *)
lemma expectation_gpv_Done [simp]: "expectation_gpv (Done x) = f x"
  by(subst expectation_gpv.simps)(simp add: measure_spmf_return_spmf nn_integral_return)

(* Simp rule: on the failing computation the value is fail. *)
lemma expectation_gpv_Fail [simp]: "expectation_gpv Fail = fail"
  by(subst expectation_gpv.simps) simp

(* Simp rule: on a lifted spmf p the value is the integral of f over p plus fail
   weighted by the probability of None in p. *)
lemma expectation_gpv_lift_spmf [simp]: 
  "expectation_gpv (lift_spmf p) = (+ x. f x measure_spmf p) + fail * pmf p None"
  by(subst expectation_gpv.simps)(auto simp add: o_def pmf_map vimage_def measure_pmf_single)

(* Simp rule: on Pause out c the value is the INF over the responses r for out of the
   value on the continuation c r. *)
lemma expectation_gpv_Pause [simp]:
  "expectation_gpv (Pause out c) = (INF rresponses_ℐ  out. expectation_gpv (c r))"
  by(subst expectation_gpv.simps)(simp add: measure_spmf_return_spmf nn_integral_return)

end

(* Parametricity of expectation_gpv w.r.t. rel_gpv''.
   A private copy weight_spmf' of weight_spmf is introduced, together with the matching
   parametricity rule, and is used as the transfer rule in the proof below. *)
context begin
private definition "weight_spmf' p = weight_spmf p"
lemmas weight_spmf'_parametric = weight_spmf_parametric[folded weight_spmf'_def]
(* Proof outline: unfold the fixpoint definition, introduce the first three arguments
   with rel_funI, apply fixp_lfp_parametric_eq to both fixpoints, fold the auxiliary
   constants (nn_integral_spmf, Set.is_empty, weight_spmf') so that transfer rules
   apply, and finish with transfer_prover. *)
lemma expectation_gpv_parametric':
  includes lifting_syntax notes weight_spmf'_parametric[transfer_rule]
  shows "((=) ===> rel_ℐ C R ===> (A ===> (=)) ===> rel_gpv'' A C R ===> (=)) expectation_gpv expectation_gpv"
  unfolding expectation_gpv_def
  apply(rule rel_funI)
  apply(rule rel_funI)
  apply(rule rel_funI)
  apply(rule fixp_lfp_parametric_eq[OF expectation_gpv.mono expectation_gpv.mono])
  apply(fold nn_integral_spmf_def Set.is_empty_def pmf_None_eq_weight_spmf[symmetric])
  apply(simp only: weight_spmf'_def[symmetric])
  subgoal premises [transfer_rule] supply the_gpv_parametric'[transfer_rule] by transfer_prover
  done
end

(* Transfer rule: parametricity of expectation_gpv w.r.t. rel_gpv, obtained from the
   rel_gpv'' version by instantiating the response relation with equality. *)
lemma expectation_gpv_parametric [transfer_rule]:
  includes lifting_syntax
  shows "((=) ===> rel_ℐ C (=) ===> (A ===> (=)) ===> rel_gpv A C ===> (=)) expectation_gpv expectation_gpv"
using expectation_gpv_parametric'[of C "(=)" A] by(simp add: rel_gpv_conv_rel_gpv'')

(* Congruence rule for expectation_gpv: equal failure values, equal interfaces, equal
   GPVs, and functions f and g that agree on the results of gpv' (w.r.t. ℐ') give equal
   values.  Proved by parallel fixpoint induction over both occurrences.
   NOTE(review): the name of the interface assumption and the interface variable are
   garbled in this copy of the statement and of the `unfolding` line. *)
lemma expectation_gpv_cong:
  fixes fail fail'
  assumes fail: "fail = fail'"
  and: " = ℐ'"
  and gpv: "gpv = gpv'"
  and f: "x. x  results_gpv ℐ' gpv'  f x = g x"
  shows "expectation_gpv fail  f gpv = expectation_gpv fail' ℐ' g gpv'"
using f unfolding[symmetric] gpv[symmetric] fail[symmetric]
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions expectation_gpv.mono expectation_gpv.mono expectation_gpv_def expectation_gpv_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv' expectation_gpv'') show ?case
    by(rule arg_cong2[where f="(+)"] nn_integral_cong_AE)+(clarsimp simp add: step.prems results_gpv.intros split!: generat.split intro!: INF_cong[OF refl] step.IH)+
qed

(* For a colossless GPV the value of expectation_gpv does not depend on the failure
   value: fail may be replaced by any fail'.  Proved by parallel fixpoint induction,
   using that the underlying spmf is lossless (so None has probability 0) and that
   continuations are again colossless. *)
lemma expectation_gpv_cong_fail:
  "colossless_gpv  gpv  expectation_gpv fail  f gpv = expectation_gpv fail'  f gpv" for fail
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions expectation_gpv.mono expectation_gpv.mono expectation_gpv_def expectation_gpv_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv' expectation_gpv'')
  from colossless_gpv_lossless_spmfD[OF step.prems] show ?case
    by(auto simp add: lossless_iff_pmf_None intro!: nn_integral_cong_AE INF_cong step.IH intro: colossless_gpv_continuationD[OF step.prems] split: generat.split)
qed

(* Monotonicity of expectation_gpv in the failure value and in the result function:
   an ordering between fail and fail' and between f and g carries over to the values.
   NOTE(review): the comparison symbols are missing in this copy; presumably "≤"
   throughout.  Proved by parallel fixpoint induction. *)
lemma expectation_gpv_mono:
  fixes fail fail'
  assumes fail: "fail  fail'"
  and fg: "f  g"
  shows "expectation_gpv fail  f gpv  expectation_gpv fail'  g gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions expectation_gpv.mono expectation_gpv.mono expectation_gpv_def expectation_gpv_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv' expectation_gpv'')
  show ?case
    by(intro add_mono mult_right_mono fail nn_integral_mono_AE)
      (auto split: generat.split simp add: fg[THEN le_funD] INF_mono rev_bexI step.IH)
qed

(* Stronger monotonicity rule: the comparison of fail and fail' is only required when
   gpv is not colossless, and f and g need only be compared on the results of gpv.
   Proof: replace fail by ?fail (which is fail' in the colossless case) using
   expectation_gpv_cong_fail, replace f by ?f (which is g outside the results) using
   expectation_gpv_cong, then apply expectation_gpv_mono.
   NOTE(review): comparison symbols are missing in this copy; presumably "≤". *)
lemma expectation_gpv_mono_strong:
  fixes fail fail'
  assumes fail: "¬ colossless_gpv  gpv  fail  fail'"
  and fg: "x. x  results_gpv  gpv  f x  g x"
  shows "expectation_gpv fail  f gpv  expectation_gpv fail'  g gpv"
proof -
  let ?fail = "if colossless_gpv  gpv then fail' else fail"
    and ?f = "λx. if x  results_gpv  gpv then f x else g x"
  have "expectation_gpv fail  f gpv = expectation_gpv ?fail  f gpv" by(simp cong: expectation_gpv_cong_fail)
  also have " = expectation_gpv ?fail  ?f gpv" by(rule expectation_gpv_cong; simp)
  also have "  expectation_gpv fail'  g gpv" using assms by(simp add: expectation_gpv_mono le_fun_def)
  finally show ?thesis .
qed

(* Simp rule for bind_gpv: the value of expectation_gpv (with result function f) on
   bind_gpv gpv g equals the value on gpv alone when the result function is replaced by
   "expectation_gpv fail _ f after g".
   The two sides are abbreviated as expectation_gpv1 / expectation_gpv2 and the equality
   is shown as two inequalities (antisym), each by fixpoint induction on the side being
   bounded from above, unfolding the defining equation and bind_gpv.sel and comparing
   the integrals pointwise.
   NOTE(review): the interface argument, the composition operator and the comparison
   symbols are missing in this copy. *)
lemma expectation_gpv_bind [simp]:
  fixes  f g fail
  defines "expectation_gpv1  expectation_gpv fail  f"
  and "expectation_gpv2  expectation_gpv fail  (expectation_gpv fail  f  g)"
  shows "expectation_gpv1 (bind_gpv gpv g) = expectation_gpv2 gpv" (is "?lhs = ?rhs")
proof(rule antisym)
  note [simp] = case_map_generat o_def
    and [cong del] = generat.case_cong_weak
  show "?lhs  ?rhs" unfolding expectation_gpv1_def
  proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case (step expectation_gpv')
    show ?case unfolding expectation_gpv2_def
      apply(rewrite bind_gpv.sel)
      apply(simp add: map_spmf_bind_spmf measure_spmf_bind)
      apply(rewrite nn_integral_bind[where B="measure_spmf _"])
        apply(simp_all add: space_subprob_algebra)
      apply(rewrite expectation_gpv.simps)
      apply(simp add: pmf_bind_spmf_None distrib_left nn_integral_eq_integral[symmetric] measure_spmf.integrable_const_bound[where B=1] pmf_le_1 nn_integral_cmult[symmetric] nn_integral_add[symmetric])
      apply(rule disjI2)
      apply(rule nn_integral_mono)
      apply(clarsimp split!: generat.split)
       apply(rewrite expectation_gpv.simps)
       apply simp
       apply(rule disjI2)
       apply(rule nn_integral_mono)
       apply(clarsimp split: generat.split)
       apply(rule INF_mono)
       apply(erule rev_bexI)
       apply(rule step.hyps)
      apply(clarsimp simp add: measure_spmf_return_spmf nn_integral_return)
      apply(rule INF_mono)
      apply(erule rev_bexI)
      apply(rule step.IH[unfolded expectation_gpv2_def o_def])
      done
  qed
  show "?rhs  ?lhs" unfolding expectation_gpv2_def
  proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case (step expectation_gpv')
    show ?case unfolding expectation_gpv1_def
      apply(rewrite in "_  " expectation_gpv.simps)
      apply(rewrite bind_gpv.sel)
      apply(simp add: measure_spmf_bind)
      apply(rewrite nn_integral_bind[where B="measure_spmf _"])
        apply(simp_all add: space_subprob_algebra)
      apply(simp add: pmf_bind_spmf_None distrib_left nn_integral_eq_integral[symmetric] measure_spmf.integrable_const_bound[where B=1] pmf_le_1 nn_integral_cmult[symmetric] nn_integral_add[symmetric])
      apply(rule disjI2)
      apply(rule nn_integral_mono)
      apply(clarsimp split!: generat.split)
       apply(rewrite expectation_gpv.simps)
       apply(simp cong del: if_weak_cong add: generat.map_comp id_def[symmetric] generat.map_id)
      apply(simp add: measure_spmf_return_spmf nn_integral_return)
      apply(rule INF_mono)
      apply(erule rev_bexI)
      apply(rule step.IH[unfolded expectation_gpv1_def])
      done
  qed
qed

(* Simp rule for try_gpv: the value of expectation_gpv on try_gpv gpv gpv' equals the
   value on gpv alone when the failure value is replaced by the value of
   expectation_gpv on the fallback gpv'.
   The equality is shown as two inequalities (antisym), each by fixpoint induction,
   unfolding the defining equation on both sides and using nn_integral_try_spmf.
   NOTE(review): the interface argument and the comparison symbols are missing in this
   copy. *)
lemma expectation_gpv_try_gpv [simp]:
  fixes fail  f gpv'
  defines "expectation_gpv1  expectation_gpv fail  f"
    and "expectation_gpv2  expectation_gpv (expectation_gpv fail  f gpv')  f"
  shows "expectation_gpv1 (try_gpv gpv gpv') = expectation_gpv2 gpv"
proof(rule antisym)
  show "expectation_gpv1 (try_gpv gpv gpv')  expectation_gpv2 gpv" unfolding expectation_gpv1_def
  proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case step [unfolded expectation_gpv2_def]: (step expectation_gpv')
    show ?case unfolding expectation_gpv2_def
      apply(rewrite expectation_gpv.simps)
      apply(rewrite in "_  _ + " expectation_gpv.simps)
      apply(simp add: pmf_map_spmf_None nn_integral_try_spmf o_def generat.map_comp case_map_generat distrib_right cong del: generat.case_cong_weak)
      apply(simp add: mult_ac add.assoc ennreal_mult)
      apply(intro disjI2 add_mono mult_left_mono nn_integral_mono; clarsimp split: generat.split intro!: INF_mono step elim!: rev_bexI)
      done
  qed
  show "expectation_gpv2 gpv  expectation_gpv1 (try_gpv gpv gpv')" unfolding expectation_gpv2_def
  proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case step [unfolded expectation_gpv1_def]: (step expectation_gpv')
    show ?case unfolding expectation_gpv1_def
      apply(rewrite in "_  " expectation_gpv.simps)
      apply(rewrite in "  _" expectation_gpv.simps)
      apply(simp add: pmf_map_spmf_None nn_integral_try_spmf o_def generat.map_comp case_map_generat distrib_left ennreal_mult mult_ac id_def[symmetric] generat.map_id cong del: generat.case_cong_weak)
      apply(rule disjI2 nn_integral_mono)+
      apply(clarsimp split: generat.split intro!: INF_mono step(2) elim!: rev_bexI)
      done
  qed
qed

(* For a GPV that satisfies the typing judgement (⊢g), applying restrict_gpv does not
   change the value of expectation_gpv.  Proved by fixpoint induction; the two summands
   of the defining equation (integral part and failure part) are treated separately via
   arg_cong2 on "+".
   NOTE(review): the interface argument and the implication are missing in this copy. *)
lemma expectation_gpv_restrict_gpv:
  " ⊢g gpv   expectation_gpv fail  f (restrict_gpv  gpv) = expectation_gpv fail  f gpv" for fail
proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv'')
  show ?case
    apply(simp add: pmf_map vimage_def)
    apply(rule arg_cong2[where f="(+)"])
    subgoal by(clarsimp simp add: measure_spmf_def nn_integral_distr nn_integral_restrict_space step.IH WT_gpv_ContD[OF step.prems] AE_measure_pmf_iff in_set_spmf[symmetric] WT_gpv_OutD[OF step.prems] split!: option.split generat.split intro!: nn_integral_cong_AE INF_cong[OF refl])
    apply(simp add: measure_pmf_single[symmetric])
    apply(rule arg_cong[where f="λx. _ * ennreal x"])
    apply(rule measure_pmf.finite_measure_eq_AE)
    apply(auto simp add: AE_measure_pmf_iff in_set_spmf[symmetric] intro: WT_gpv_OutD[OF step.prems] split: option.split_asm generat.split_asm if_split_asm)
    done
qed

(* Bound for a constant result function: for a GPV satisfying the typing judgement (⊢g),
   the value of expectation_gpv with result function (λ_. c) is compared with
   max c fail.
   NOTE(review): the comparison symbol is missing in this copy; presumably "≤".
   Proof by fixpoint induction: bound the integrand pointwise by max c fail, bound fail
   by max c fail, and combine both parts using the weight of the underlying spmf.
   NOTE(review): the metis call uses the legacy option name `hide_lams`; newer Isabelle
   releases may expect `opaque_lifting` instead -- confirm against the targeted
   Isabelle version. *)
lemma expectation_gpv_const_le: " ⊢g gpv   expectation_gpv fail  (λ_. c) gpv  max c fail" for fail
proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv')
  have "integralN (measure_spmf (the_gpv gpv)) (case_generat (λx. c) (λout c. INF rresponses_ℐ  out. expectation_gpv' (c r)))  integralN (measure_spmf (the_gpv gpv)) (λ_. max c fail)"
    using step.prems
    by(intro nn_integral_mono_AE)(auto 4 4 split: generat.split intro: INF_lower2 step.IH WT_gpv_ContD[OF step.prems] dest!: WT_gpv_OutD simp add: in_outs_ℐ_iff_responses_ℐ)
  also have " + fail * pmf (the_gpv gpv) None   + max c fail * pmf (the_gpv gpv) None"
    by(intro add_left_mono mult_right_mono) simp_all
  also have "  max c fail"
    by(simp add: measure_spmf.emeasure_eq_measure pmf_None_eq_weight_spmf ennreal_minus[symmetric])
      (metis (no_types, hide_lams) add_diff_eq_iff_ennreal distrib_left ennreal_le_1 le_max_iff_disj max.cobounded2 mult.commute mult.left_neutral weight_spmf_le_1)
  finally show ?case by(simp add: add_mono)
qed

(* If a GPV has no results (results_gpv is empty) and satisfies the typing judgement
   (⊢g), then expectation_gpv with failure value 0 is 0, for any result function f.
   Proved by fixpoint induction; the auxiliary fact states that the continuations
   reachable through an IO step also have no results. *)
lemma expectation_gpv_no_results:
   " results_gpv  gpv = {};  ⊢g gpv    expectation_gpv 0  f gpv = 0"
proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv')
  have "results_gpv  (c x) = {}" if "IO out c  set_spmf (the_gpv gpv)" "x  responses_ℐ  out"
    for out c x using that step.prems(1) by(auto intro: results_gpv.IO)
  then show ?case using step.prems
    by(auto 4 4 intro!: nn_integral_zero' split: generat.split intro: results_gpv.Pure cong: INF_cong simp add: step.IH WT_gpv_ContD INF_constant in_outs_ℐ_iff_responses_ℐ dest: WT_gpv_OutD)
qed

(* Scaling: for a constant c with 0 < c and a second side condition on c, multiplying
   the value of expectation_gpv by c equals scaling both the failure value and the
   result function by c.
   NOTE(review): the second assumption on c is garbled in this copy; presumably c is
   not the top element.
   Proved by parallel fixpoint induction, using INF_mult_left_ennreal to move c inside
   the infimum over responses. *)
lemma expectation_gpv_cmult:
  fixes fail
  assumes "0 < c" and "c  "
  shows "c * expectation_gpv fail  f gpv = expectation_gpv (c * fail)  (λx. c * f x) gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions expectation_gpv.mono expectation_gpv.mono expectation_gpv_def expectation_gpv_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by(simp add: bot_ennreal_def)
  case (step expectation_gpv' expectation_gpv'')
  show ?case using assms
    apply(simp add: distrib_left mult_ac nn_integral_cmult[symmetric] generat.case_distrib[where h="(*) _"])
    apply(subst INF_mult_left_ennreal, simp_all add: step.IH)
    done
qed

(* Relates expectation_gpv (with failure value 0) to running the GPV against a concrete
   callee: its value is compared with the integral of f over the results of
   exec_gpv callee gpv s.
   Assumptions: the callee is lossless on queries in outs_ℐ of the interface, gpv
   satisfies the typing judgement (⊢g), and the callee satisfies the callee typing
   judgement (⊢c) in every state.
   NOTE(review): the comparison symbol is missing in this copy; presumably "≤".
   Proof: parallel fixpoint induction over expectation_gpv and exec_gpv.  The auxiliary
   fact * handles an IO step: the infimum over responses is bounded by the integral over
   the callee's (lossless) answer distribution, and the induction hypothesis is then
   applied under the integral. *)
lemma expectation_gpv_le_exec_gpv:
  assumes callee: "s x. x  outs_ℐ   lossless_spmf (callee s x)"
    and WT_gpv: " ⊢g gpv "
    and WT_callee: "s.  ⊢c callee s "
  shows "expectation_gpv 0  f gpv  + (x, s). f x measure_spmf (exec_gpv callee gpv s)"
using WT_gpv
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_1_2[OF complete_lattice_partial_function_definitions partial_function_definitions_spmf expectation_gpv.mono exec_gpv.mono expectation_gpv_def exec_gpv_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by(simp add: bot_ennreal_def)
  case (step expectation_gpv'' exec_gpv')
  have *: "(INF rresponses_ℐ  out. expectation_gpv'' (c r))  + (x, s). f x measure_spmf (bind_spmf (callee s out) (λ(r, s'). exec_gpv' (c r) s'))" (is "?lhs  ?rhs")
    if "IO out c  set_spmf (the_gpv gpv)" for out c 
  proof -
    from step.prems that have out: "out  outs_ℐ " by(rule WT_gpvD)
    have "?lhs = + _. ?lhs measure_spmf (callee s out)" using callee[OF out, THEN lossless_weight_spmfD]
      by(simp add: measure_spmf.emeasure_eq_measure)
    also have "  + (r, s'). expectation_gpv'' (c r) measure_spmf (callee s out)"
      by(rule nn_integral_mono_AE)(auto intro: WT_calleeD[OF WT_callee _ out] INF_lower)
    also have "  + (r, s'). + (x, _). f x measure_spmf (exec_gpv' (c r) s') measure_spmf (callee s out)"
      by(rule nn_integral_mono_AE)(auto intro!: step.IH intro: WT_gpv_ContD[OF step.prems that] WT_calleeD[OF WT_callee _ out])
    also have " = ?rhs" by(simp add: measure_spmf_bind split_def nn_integral_bind[where B="measure_spmf _"] o_def space_subprob_algebra)
    finally show ?thesis .
  qed
  show ?case
    by(simp add: measure_spmf_bind nn_integral_bind[where B="measure_spmf _"] space_subprob_algebra)
      (simp split!: generat.split add: measure_spmf_return_spmf nn_integral_return * nn_integral_mono_AE)
qed

(* weight_gpv of a GPV w.r.t. an interface: expectation_gpv with failure value 0 and the
   constant result function 1, converted to a real with enn2real.
   NOTE(review): the interface argument and a type arrow are missing in this copy. *)
definition weight_gpv :: "('out, 'ret) ('a, 'out, 'ret) gpv  real"
  where "weight_gpv  gpv = enn2real (expectation_gpv 0  (λ_. 1) gpv)"

(* Simp rule: a GPV that immediately returns a result has weight 1. *)
lemma weight_gpv_Done [simp]: "weight_gpv  (Done x) = 1"
by(simp add: weight_gpv_def)

(* Simp rule: the failing GPV has weight 0. *)
lemma weight_gpv_Fail [simp]: "weight_gpv  Fail = 0"
by(simp add: weight_gpv_def)

(* Simp rule: the weight of a lifted spmf p is weight_spmf p. *)
lemma weight_gpv_lift_spmf [simp]: "weight_gpv  (lift_spmf p) = weight_spmf p"
by(simp add: weight_gpv_def measure_spmf.emeasure_eq_measure)

(* Simp rule for Pause out c, assuming the continuations c r for responses r to out
   satisfy the typing judgement (⊢g): the weight is the INF over the responses of the
   weights of the continuations if out is tested positively against outs_ℐ of the
   interface, and 0 otherwise.
   NOTE(review): the membership test in the condition and the interface argument are
   missing in this copy.
   The proof moves enn2real inside the infimum with enn2real_INF; finiteness comes from
   expectation_gpv_const_le. *)
lemma weight_gpv_Pause [simp]:
  "(r. r  responses_ℐ  out   ⊢g c r )
    weight_gpv  (Pause out c) = (if out  outs_ℐ  then INF rresponses_ℐ  out. weight_gpv  (c r) else 0)"
apply(clarsimp simp add: weight_gpv_def in_outs_ℐ_iff_responses_ℐ)
apply(erule enn2real_INF)
apply(clarsimp simp add: expectation_gpv_const_le[THEN le_less_trans])
done

(* Compares 0 with weight_gpv; presumably non-negativity (0 ≤ weight), as the lemma name
   suggests -- the comparison symbol is missing in this copy. *)
lemma weight_gpv_nonneg: "0  weight_gpv  gpv"
by(simp add: weight_gpv_def)

(* For a GPV satisfying the typing judgement (⊢g), compares weight_gpv with 1;
   presumably weight ≤ 1, as the lemma name suggests -- the comparison symbol is missing
   in this copy.  Derived from expectation_gpv_const_le with fail = 0 and c = 1. *)
lemma weight_gpv_le_1: " ⊢g gpv   weight_gpv  gpv  1"
using expectation_gpv_const_le[of  gpv 0 1] by(simp add: weight_gpv_def enn2real_leI max_def)

(* Relates the weight of a GPV to the weight of the spmf obtained by running
   it against a callee with exec_gpv.
   Assumptions:
     callee    - callee s x is a lossless spmf for every x in outs_ℐ ...;
     WT_gpv    - gpv satisfies the "⊢g" judgement;
     WT_callee - callee satisfies the "⊢c" judgement for every state s.
   The proof bounds the expectation with expectation_gpv_le_exec_gpv, rewrites
   the integral of the constant 1 as weight_spmf, and transfers the bound to
   reals with enn2real_leI. *)
theorem weight_exec_gpv:
  assumes callee: "s x. x  outs_ℐ   lossless_spmf (callee s x)"
    and WT_gpv: " ⊢g gpv "
    and WT_callee: "s.  ⊢c callee s "
  shows "weight_gpv  gpv  weight_spmf (exec_gpv callee gpv s)"
proof -
  have "expectation_gpv 0  (λ_. 1) gpv  + (x, s). 1 measure_spmf (exec_gpv callee gpv s)"
    using assms by(rule expectation_gpv_le_exec_gpv)
  (* The integral of the constant 1 is the weight of the resulting spmf. *)
  also have " = weight_spmf (exec_gpv callee gpv s)"
    by(simp add: split_def measure_spmf.emeasure_eq_measure)
  finally show ?thesis by(simp add: weight_gpv_def enn2real_leI)
qed

(* Variant of weight_exec_gpv inside the locale callee_invariant_on: the
   callee only needs to be lossless on states satisfying the invariant I, and
   the start state s is assumed to satisfy I.
   Proof idea: assume a type 's' defined by the set {s. I s}, transfer callee,
   state and weight function to that type, apply the unconditional theorem
   weight_exec_gpv there, transfer the result back, and finally remove the
   type-definition assumption with cancel_type_definition. *)
lemma (in callee_invariant_on) weight_exec_gpv:
  assumes callee: "s x.  x  outs_ℐ ; I s   lossless_spmf (callee s x)"
  and WT_gpv: " ⊢g gpv "
  and I: "I s"
  shows "weight_gpv  gpv  weight_spmf (exec_gpv callee gpv s)"
including lifting_syntax
proof -
  (* Block under the assumption that a type with carrier {s. I s} exists. *)
  { assume "(Rep :: 's'  's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s'  's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    (* Correspondence relation between original states and the new type. *)
    define cr where "cr  λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp
    
    (* Relation restricting queries to those in outs_ℐ ... *)
    let ?C = "eq_onp (λx. x  outs_ℐ )"

    (* callee' is callee with its state argument and result state mapped
       through Rep and Abs. *)
    define callee' where "callee'  (Rep ---> id ---> map_spmf (map_prod id Abs)) callee"
    have [transfer_rule]: "(cr ===> ?C ===> rel_spmf (rel_prod (=) cr)) callee callee'"
      by(auto simp add: callee'_def rel_fun_def cr_def spmf_rel_map prod.rel_map td.Abs_inverse eq_onp_def intro!: rel_spmf_reflI intro: td.Rep[simplified] dest: callee_invariant)
    define s' where "s'  Abs s"
    have [transfer_rule]: "cr s s'" using I by(simp add: cr_def s'_def td.Abs_inverse)

    have [transfer_rule]: "rel_ℐ ?C (=)  "
      by(rule rel_ℐI)(auto simp add: rel_set_eq set_relator_eq_onp eq_onp_same_args dest: eq_onp_to_eq)
    note [transfer_rule] = bi_unique_eq_onp bi_unique_eq

    (* gpv' is gpv passed through restrict_gpv; it is related to itself
       under ?C. *)
    define gpv' where "gpv'  restrict_gpv  gpv"
    have [transfer_rule]: "rel_gpv (=) ?C gpv' gpv'"
      by(fold eq_onp_top_eq_eq)(auto simp add: gpv.rel_eq_onp eq_onp_same_args pred_gpv_def gpv'_def dest: in_outs'_restrict_gpvD)

    (* Type-specialised copies of weight_spmf so that transfer can relate the
       two state types. *)
    define weight_spmf' :: "('c × 's') spmf  real" where "weight_spmf'  weight_spmf"
    define weight_spmf'' :: "('c × 's) spmf  real" where "weight_spmf''  weight_spmf"
    have [transfer_rule]: "(rel_spmf (rel_prod (=) cr) ===> (=)) weight_spmf'' weight_spmf'"
      by(simp add: weight_spmf'_def weight_spmf''_def weight_spmf_parametric)

    (* Establish the three premises of the unconditional theorem for the
       primed objects, then apply it. *)
    have [rule_format]: "s. x  outs_ℐ . lossless_spmf (callee' s x)"
      by(transfer)(blast intro: callee)
    moreover have " ⊢g gpv' " by(simp add: gpv'_def)
    moreover have "s.  ⊢c callee' s " by transfer(rule WT_callee)
    ultimately have **: "weight_gpv  gpv'  weight_spmf' (exec_gpv callee' gpv' s')"
      unfolding weight_spmf'_def by(rule weight_exec_gpv)
    have [transfer_rule]: "((=) ===> ?C ===> rel_spmf (rel_prod (=) (=))) callee callee"
      by(simp add: rel_fun_def eq_onp_def prod.rel_eq)
    (* Transfer the bound back to the original state type. *)
    have "weight_gpv  gpv'  weight_spmf'' (exec_gpv callee gpv' s)" using ** by transfer
    (* Replace gpv' by gpv on both sides. *)
    also have "exec_gpv callee gpv' s = exec_gpv callee gpv s"
      unfolding gpv'_def using WT_gpv I by(rule exec_gpv_restrict_gpv_invariant)
    also have "weight_gpv  gpv' = weight_gpv  gpv" using WT_gpv 
      by(simp add: gpv'_def expectation_gpv_restrict_gpv weight_gpv_def)
    finally have ?thesis by(simp add: weight_spmf''_def) }
  (* Discharge the type-definition assumption; I supplies a witness state. *)
  from this[cancel_type_definition] I show ?thesis by blast
qed

subsection ‹Probabilistic termination›

(* Generalised probabilistic losslessness, parameterised by the value `fail`
   passed to expectation_gpv: holds iff the expectation of the constant
   function 1 equals 1. *)
definition pgen_lossless_gpv :: "ennreal  ('c, 'r) ('a, 'c, 'r) gpv  bool"
where "pgen_lossless_gpv fail  gpv = (expectation_gpv fail  (λ_. 1) gpv = 1)" for fail

(* plossless_gpv is pgen_lossless_gpv with fail = 0. *)
abbreviation plossless_gpv :: "('c, 'r) ('a, 'c, 'r) gpv  bool"
where "plossless_gpv  pgen_lossless_gpv 0"

(* pfinite_gpv is pgen_lossless_gpv with fail = 1. *)
abbreviation pfinite_gpv :: "('c, 'r) ('a, 'c, 'r) gpv  bool"
where "pfinite_gpv  pgen_lossless_gpv 1"

(* Introduction rule: to show pgen_lossless_gpv it suffices to show that the
   expectation of the constant 1 equals 1 (the definition, read right to
   left). *)
lemma pgen_lossless_gpvI [intro?]: "expectation_gpv fail  (λ_. 1) gpv = 1  pgen_lossless_gpv fail  gpv" for fail
by(simp add: pgen_lossless_gpv_def)

(* Destruction rule: pgen_lossless_gpv yields the defining equation
   expectation_gpv fail ... (λ_. 1) gpv = 1. *)
lemma pgen_lossless_gpvD: "pgen_lossless_gpv fail  gpv  expectation_gpv fail  (λ_. 1) gpv = 1" for fail
by(simp add: pgen_lossless_gpv_def)

(* lossless_gpv together with the "⊢g" judgement implies plossless_gpv.
   Proved by induction with lossless_WT_gpv_induct: the expectation is
   unfolded once, the induction hypothesis turns every continuation's
   expectation into 1, and the remaining integral of the constant 1 over a
   lossless spmf is 1. *)
lemma lossless_imp_plossless_gpv:
  assumes "lossless_gpv  gpv" " ⊢g gpv "
  shows "plossless_gpv  gpv"
proof
  show "expectation_gpv 0  (λ_. 1) gpv = 1" using assms
  proof(induction rule: lossless_WT_gpv_induct)
    case (lossless_gpv p)
    (* Unfold expectation_gpv once and apply the induction hypothesis. *)
    have "expectation_gpv 0  (λ_. 1) (GPV p) = nn_integral (measure_spmf p) (case_generat (λ_. 1) (λout c. INF rresponses_ℐ  out. 1))"
      by(subst expectation_gpv.simps)(clarsimp split: generat.split cong: INF_cong simp add: lossless_gpv.IH intro!: nn_integral_cong_AE)
    (* The integrand is 1 almost everywhere on the support of p. *)
    also have " = nn_integral (measure_spmf p) (λ_. 1)"
      by(intro nn_integral_cong_AE)(auto split: generat.split dest!: lossless_gpv.hyps(2) simp add: in_outs_ℐ_iff_responses_ℐ)
    finally show ?case by(simp add: measure_spmf.emeasure_eq_measure lossless_weight_spmfD lossless_gpv.hyps(1))
  qed
qed

(* finite_gpv together with the "⊢g" judgement implies pfinite_gpv.
   Proved by induction with finite_gpv_induct. With fail = 1 the expectation
   splits into the integral over the spmf plus the probability of None
   (pmf ... None); the final step uses pmf_None_eq_weight_spmf. *)
lemma finite_imp_pfinite_gpv:
  assumes "finite_gpv  gpv" " ⊢g gpv "
  shows "pfinite_gpv  gpv"
proof
  show "expectation_gpv 1  (λ_. 1) gpv = 1" using assms
  proof(induction rule: finite_gpv_induct)
    case (finite_gpv gpv)
    (* Unfold expectation_gpv once; continuations contribute 1 by the IH. *)
    then have "expectation_gpv 1  (λ_. 1) gpv = nn_integral (measure_spmf (the_gpv gpv)) (case_generat (λ_. 1) (λout c. INF rresponses_ℐ  out. 1)) + pmf (the_gpv gpv) None"
      by(subst expectation_gpv.simps)(clarsimp intro!: nn_integral_cong_AE INF_cong[OF refl] split!: generat.split simp add: WT_gpv_ContD)
    (* The integrand is 1 almost everywhere. *)
    also have " = nn_integral (measure_spmf (the_gpv gpv)) (λ_. 1) + pmf (the_gpv gpv) None"
      by(intro arg_cong2[where f="(+)"] nn_integral_cong_AE)
        (auto split: generat.split dest!: WT_gpv_OutD[OF finite_gpv.prems] simp add: in_outs_ℐ_iff_responses_ℐ)
    finally show ?case
      by(simp add: measure_spmf.emeasure_eq_measure ennreal_plus[symmetric] del: ennreal_plus)
        (simp add: pmf_None_eq_weight_spmf)
  qed
qed

(* If gpv is plossless and satisfies the "⊢g" judgement, then its first step
   the_gpv gpv is a lossless spmf.
   The proof is a calculation starting from 1 = expectation of the constant 1,
   unfolding the expectation once, bounding every continuation's contribution
   via expectation_gpv_const_le, and identifying the result with
   weight_spmf (the_gpv gpv); weight_spmf_le_1 then closes the argument. *)
lemma plossless_gpv_lossless_spmfD:
  assumes lossless: "plossless_gpv  gpv"
  and WT: " ⊢g gpv "
  shows "lossless_spmf (the_gpv gpv)"
proof -
  have "1 = expectation_gpv 0  (λ_. 1) gpv"
    using lossless by(auto dest: pgen_lossless_gpvD simp add: weight_gpv_def)
  (* One unfolding of the expectation. *)
  also have " = + generat. (case generat of Pure x  1 | IO out c  INF rresponses_ℐ  out. expectation_gpv 0  (λ_. 1) (c r)) measure_spmf (the_gpv gpv)"
    by(subst expectation_gpv.simps)(auto)
  (* Replace each continuation's INF by the constant 1. *)
  also have "  + generat. (case generat of Pure x  1 | IO out c  1) measure_spmf (the_gpv gpv)"
    apply(rule nn_integral_mono_AE)
    apply(clarsimp split: generat.split)
    apply(frule WT_gpv_OutD[OF WT])
    using expectation_gpv_const_le[of  _ 0 1]
    apply(auto simp add: in_outs_ℐ_iff_responses_ℐ max_def intro: INF_lower2 WT_gpv_ContD[OF WT] dest: WT_gpv_OutD[OF WT])
    done
  also have " = weight_spmf (the_gpv gpv)"
    by(auto simp add: weight_spmf_eq_nn_integral_spmf nn_integral_measure_spmf intro!: nn_integral_cong split: generat.split)
  finally show ?thesis using weight_spmf_le_1[of "the_gpv gpv"] by(simp add: lossless_spmf_def)
qed

(* Two destruction rules proved together:
     plossless_gpv_ContD - plossless_gpv is inherited by a continuation,
     pfinite_gpv_ContD   - pfinite_gpv is inherited by a continuation.
   In both, IO out c must be in the support of the_gpv gpv, input must be in
   responses_ℐ ... out, and gpv must satisfy the "⊢g" judgement; the
   conclusion is about c input.
   Proof structure: after pgen_lossless_gpvI and antisym, one direction comes
   from expectation_gpv_const_le; the other is by contradiction (ccontr),
   using the auxiliary fact `less`, which is parameterised by `fail` and
   instantiated for both lemmas. *)
lemma
  shows plossless_gpv_ContD:
  " plossless_gpv  gpv; IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out;  ⊢g gpv  
   plossless_gpv  (c input)"
  and pfinite_gpv_ContD:
  " pfinite_gpv  gpv; IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out;  ⊢g gpv  
   pfinite_gpv  (c input)"
proof(rule_tac [!] pgen_lossless_gpvI, rule_tac [!] antisym[rotated], rule_tac ccontr, rule_tac [3] ccontr)
  assume IO: "IO out c  set_spmf (the_gpv gpv)"
    and input: "input  responses_ℐ  out"
    and WT: " ⊢g gpv "
  (* The continuation also satisfies the "⊢g" judgement. *)
  from WT IO input have WT': " ⊢g c input " by(rule WT_gpv_ContD)
  (* Direction obtained directly from expectation_gpv_const_le, for fail = 0
     and fail = 1. *)
  from expectation_gpv_const_le[OF this, of 0 1] expectation_gpv_const_le[OF this, of 1 1]
  show "expectation_gpv 0  (λ_. 1) (c input)  1"
    and "expectation_gpv 1  (λ_. 1) (c input)  1" by(simp_all add: max_def)

  (* Auxiliary strict inequality used in both contradiction arguments; its
     hypothesis * is the negated goal for the continuation c input. *)
  have less: "expectation_gpv fail  (λ_. 1) gpv < weight_spmf (the_gpv gpv) + fail * pmf (the_gpv gpv) None"
    if fail: "fail  1" and *: "¬ 1  expectation_gpv fail  (λ_. 1) (c input)" for fail :: ennreal
  proof -
    (* Write the expectation as a sum over generats on count_space UNIV,
       separating the point IO out c from all other generats by indicators. *)
    have "expectation_gpv fail  (λ_. 1) gpv = (+ generat. (case generat of Pure x  1 | IO out c  INF rresponses_ℐ  out. expectation_gpv fail  (λ_. 1) (c r)) * spmf (the_gpv gpv) generat * indicator (UNIV - {IO out c}) generat + (INF rresponses_ℐ  out. expectation_gpv fail  (λ_. 1) (c r)) * spmf (the_gpv gpv) (IO out c) * indicator {IO out c} generat count_space UNIV) + fail * pmf (the_gpv gpv) None"
      by(subst expectation_gpv.simps)(auto simp add: nn_integral_measure_spmf mult.commute intro!: nn_integral_cong split: split_indicator generat.split)
    (* Split the integral of the sum; the two parts are named ?rest and ?cr. *)
    also have " = (+ generat. (case generat of Pure x  1 | IO out c  INF rresponses_ℐ  out. expectation_gpv fail  (λ_. 1) (c r)) * spmf (the_gpv gpv) generat * indicator (UNIV - {IO out c}) generat count_space UNIV) +
      (INF rresponses_ℐ  out. expectation_gpv fail  (λ_. 1) (c r)) * spmf (the_gpv gpv) (IO out c) + fail * pmf (the_gpv gpv) None" (is "_ = ?rest + ?cr + _")
      by(subst nn_integral_add) simp_all
    (* Side fact about ?rest (named fin), needed for the strict step below. *)
    also from calculation expectation_gpv_const_le[OF WT, of fail 1] fail have fin: "?rest  "
      by(auto simp add: top_add top_unique max_def split: if_split_asm)
    (* Bound the INF by its value at the particular response `input`. *)
    have "?cr  expectation_gpv fail  (λ_. 1) (c input) * spmf (the_gpv gpv) (IO out c)"
      by(rule mult_right_mono INF_lower[OF input])+ simp
    (* Strict step: uses * and that IO out c has positive probability. *)
    also have "?rest +  < ?rest + 1 * ennreal (spmf (the_gpv gpv) (IO out c))"
      unfolding ennreal_add_left_cancel_less using * IO
      by(intro conjI fin ennreal_mult_strict_right_mono)(simp_all add: not_le weight_gpv_def in_set_spmf_iff_spmf)
    (* Bound ?rest by dropping the case-expression factor pointwise. *)
    also have "?rest  + generat. spmf (the_gpv gpv) generat * indicator (UNIV - {IO out c}) generat count_space UNIV"
      apply(rule nn_integral_mono)
      apply(clarsimp split: generat.split split_indicator)
      apply(rule ennreal_mult_le_self2I)
      apply simp
      subgoal premises prems for out' c'
        apply(subgoal_tac "IO out' c'  set_spmf (the_gpv gpv)")
         apply(frule WT_gpv_OutD[OF WT])
         apply(simp add: in_outs_ℐ_iff_responses_ℐ)
         apply safe
         apply(erule notE)
         apply(rule INF_lower2, assumption)
         apply(rule expectation_gpv_const_le[THEN order_trans])
          apply(erule (1) WT_gpv_ContD[OF WT])
         apply(simp add: fail)
        using prems by(simp add: in_set_spmf_iff_spmf)
      done
    (* Recombine both parts into a single integral of spmf (the_gpv gpv). *)
    also have " + 1 * ennreal (spmf (the_gpv gpv) (IO out c)) = 
      (+ generat. spmf (the_gpv gpv) generat * indicator (UNIV - {IO out c}) generat + ennreal (spmf (the_gpv gpv) (IO out c)) * indicator {IO out c} generat count_space UNIV)"
      by(subst nn_integral_add)(simp_all)
    also have " = + generat. spmf (the_gpv gpv) generat count_space UNIV" 
      by(auto intro!: nn_integral_cong split: split_indicator)
    also have " = weight_spmf (the_gpv gpv)" by(simp add: nn_integral_spmf measure_spmf.emeasure_eq_measure space_measure_spmf)
    finally show ?thesis using fail
      by(fastforce simp add: top_unique add_mono ennreal_plus[symmetric] ennreal_mult_eq_top_iff)
  qed
  
  (* Contradiction for the plossless case (fail = 0). *)
  show False if *: "¬ 1  expectation_gpv 0  (λ_. 1) (c input)" and lossless: "plossless_gpv  gpv"
    using less[OF _ *] plossless_gpv_lossless_spmfD[OF lossless WT] lossless[THEN pgen_lossless_gpvD]
    by(simp add: lossless_spmf_def)

  (* Contradiction for the pfinite case (fail = 1). *)
  show False if *: "¬ 1  expectation_gpv 1  (λ_. 1) (c input)" and finite: "pfinite_gpv  gpv"
    using less[OF _ *] finite[THEN pgen_lossless_gpvD] by(simp add: ennreal_plus[symmetric] del: ennreal_plus)(simp add: pmf_None_eq_weight_spmf)
qed

(* For a GPV satisfying the "⊢g" judgement, relates plossless_gpv to the
   pair colossless_gpv and pfinite_gpv. The proof (intro iffI conjI) shows
   that plossless yields both colossless and pfinite, and conversely that
   colossless together with pfinite yields plossless. *)
lemma plossless_iff_colossless_pfinite:
  assumes WT: " ⊢g gpv "
  shows "plossless_gpv  gpv  colossless_gpv  gpv  pfinite_gpv  gpv"
proof(intro iffI conjI; (elim conjE)?)
  assume *: "plossless_gpv  gpv"
  (* colossless: by coinduction, using the two destruction rules for
     plossless_gpv. *)
  show "colossless_gpv  gpv" using * WT
  proof(coinduction arbitrary: gpv)
    case (colossless_gpv gpv)
    have ?lossless_spmf using colossless_gpv by(rule plossless_gpv_lossless_spmfD)
    moreover have ?continuation using colossless_gpv
      by(auto intro: plossless_gpv_ContD WT_gpv_ContD)
    ultimately show ?case ..
  qed

  (* pfinite: antisymmetry, using expectation_gpv_const_le for one direction
     and monotonicity of the expectation (expectation_gpv_mono, comparing
     fail = 0 with fail = 1) for the other. *)
  show "pfinite_gpv  gpv" unfolding pgen_lossless_gpv_def
  proof(rule antisym)
    from expectation_gpv_const_le[OF WT, of 1 1] show "expectation_gpv 1  (λ_. 1) gpv  1" by simp
    have "1 = expectation_gpv 0  (λ_. 1) gpv" using * by(simp add: pgen_lossless_gpv_def)
    also have "  expectation_gpv 1  (λ_. 1) gpv" by(rule expectation_gpv_mono) simp_all
    finally show "1  " .
  qed
next
  (* Converse: handled by simp with the congruence rule
     expectation_gpv_cong_fail. *)
  show "plossless_gpv  gpv" if "colossless_gpv  gpv" and "pfinite_gpv  gpv" using that
    by(simp add: pgen_lossless_gpv_def cong: expectation_gpv_cong_fail)
qed

(* Simp rule: Done x is pgen-lossless for every value of fail. *)
lemma pgen_lossless_gpv_Done [simp]: "pgen_lossless_gpv fail  (Done x)" for fail
by(simp add: pgen_lossless_gpv_def)

(* Simp rule relating pgen-losslessness of Fail to the condition fail = 1.
   NOTE(review): the connective between the two sides is not legible in this
   copy of the source; presumably an equivalence - confirm. *)
lemma pgen_lossless_gpv_Fail [simp]: "pgen_lossless_gpv fail  Fail  fail = 1" for fail
by(simp add: pgen_lossless_gpv_def)

(* Introduction rule (also a simp rule) for Pause: if out is in outs_ℐ ... and
   the continuation c r is pgen-lossless for every response r in
   responses_ℐ ... out, then Pause out c is pgen-lossless. *)
lemma pgen_lossless_gpv_PauseI [simp, intro!]: 
  " out  outs_ℐ ; r. r  responses_ℐ  out  pgen_lossless_gpv fail  (c r) 
    pgen_lossless_gpv fail  (Pause out c)" for fail
by(simp add: pgen_lossless_gpv_def weight_gpv_def in_outs_ℐ_iff_responses_ℐ)

(* Introduction rule (also a simp rule) for bind: if gpv is pgen-lossless and
   f x is pgen-lossless for every x in results_gpv ... gpv, then
   bind_gpv gpv f is pgen-lossless. Proved by simp with the congruence rule
   expectation_gpv_cong. *)
lemma pgen_lossless_gpv_bindI [simp, intro!]:
  " pgen_lossless_gpv fail  gpv; x. x  results_gpv  gpv  pgen_lossless_gpv fail  (f x) 
   pgen_lossless_gpv fail  (bind_gpv gpv f)" for fail
by(simp add: pgen_lossless_gpv_def weight_gpv_def o_def cong: expectation_gpv_cong)

(* Simp rule characterising pgen-losslessness of lift_spmf p in terms of
   lossless_spmf p and fail = 1.
   NOTE(review): the connectives in the statement are not legible in this
   copy of the source - confirm against the original.
   The proof distinguishes cases on fail (an ennreal); the second subgoal's
   simp set uses ennreal_top_mult and add_top. *)
lemma pgen_lossless_gpv_lift_spmf [simp]: 
  "pgen_lossless_gpv fail  (lift_spmf p)  lossless_spmf p  fail = 1" for fail
apply(cases fail)
subgoal
  by(simp add: pgen_lossless_gpv_def lossless_spmf_def measure_spmf.emeasure_eq_measure pmf_None_eq_weight_spmf ennreal_minus ennreal_mult[symmetric] weight_spmf_le_1 ennreal_plus[symmetric] del: ennreal_plus)
    (metis add_diff_cancel_left' diff_add_cancel eq_iff_diff_eq_0 mult_cancel_right1)
subgoal by(simp add: pgen_lossless_gpv_def measure_spmf.emeasure_eq_measure ennreal_top_mult lossless_spmf_def add_top weight_spmf_conv_pmf_None)
done

(* For a pfinite gpv, fixes the value of expectation_gpv for a particular
   fail value and constant function.
   NOTE(review): the constants in the statement are not legible in this copy
   of the source; the lemma name and the use of ennreal_mult_top suggest the
   top element of ennreal - confirm.
   Proof by contradiction: the expectation is bounded from one side by 1
   (pfinite plus expectation_gpv_mono), while scaling by 2 via
   expectation_gpv_cmult forces it to be 0 under the negated thesis. *)
lemma expectation_gpv_top_pfinite:
  assumes "pfinite_gpv  gpv"
  shows "expectation_gpv   (λ_. ) gpv = "
proof(rule ccontr)
  assume *: "¬ ?thesis"
  have "1 = expectation_gpv 1  (λ_. 1) gpv" using assms by(simp add: pgen_lossless_gpv_def)
  also have "  expectation_gpv   (λ_. ) gpv" by(rule expectation_gpv_mono)(simp_all add: le_fun_def)
  also have " = 0"  using expectation_gpv_cmult[of "2"   "λ_. " gpv] *
    by(simp add: ennreal_mult_top) (metis ennreal_mult_cancel_left mult.commute mult_numeral_1_right not_gr_zero numeral_eq_one_iff semiring_norm(85) zero_neq_numeral)
  finally show False by simp
qed

(* For a pfinite gpv, compares c with expectation_gpv fail ... f gpv, where c
   is the minimum of fail and the infimum of f over results_gpv ... gpv.
   Case c > 0: rewrite c as c times the expectation of the constant 1, pull
   the factor inside with expectation_gpv_cmult (the case where c is the
   extremal ennreal value is handled via expectation_gpv_top_pfinite), then
   compare with the right-hand side using expectation_gpv_mono_strong.
   The remaining case is closed by simp. *)
lemma pfinite_INF_le_expectation_gpv:
  fixes fail  gpv f
  defines "c  min (INF xresults_gpv  gpv. f x) fail"
  assumes fin: "pfinite_gpv  gpv"
  shows "c  expectation_gpv fail  f gpv" (is "?lhs  ?rhs")
proof(cases "c > 0")
  case True
  have "c = c * expectation_gpv 1  (λ_. 1) gpv" using assms by(simp add: pgen_lossless_gpv_def)
  also have " = expectation_gpv c  (λ_. c) gpv" using fin True
    by(cases "c = ")(simp_all add: expectation_gpv_top_pfinite ennreal_top_mult expectation_gpv_cmult, simp add: pgen_lossless_gpv_def)
  also have "  ?rhs" by(rule expectation_gpv_mono_strong)(auto simp add: c_def min_def intro: INF_lower2)
  finally show ?thesis .
qed simp

(* For a plossless gpv satisfying the "⊢g" judgement, compares the infimum of
   f over results_gpv ... gpv with expectation_gpv fail ... f gpv, for any
   fail. The proof obtains pfinite and colossless from
   plossless_iff_colossless_pfinite, applies pfinite_INF_le_expectation_gpv,
   and then changes the fail value using expectation_gpv_cong_fail (which
   needs colossless). *)
lemma plossless_INF_le_expectation_gpv:
  fixes fail
  assumes "plossless_gpv  gpv" and " ⊢g gpv "
  shows "(INF xresults_gpv  gpv. f x)  expectation_gpv fail  f gpv" (is "?lhs  ?rhs")
proof -
  from assms have fin: "pfinite_gpv  gpv" and co: "colossless_gpv  gpv"
    by(simp_all add: plossless_iff_colossless_pfinite)
  have "?lhs  min ?lhs " by(simp add: min_def)
  also have "  expectation_gpv   f gpv" using fin by(rule pfinite_INF_le_expectation_gpv)
  also have " = ?rhs" using co by(simp add: expectation_gpv_cong_fail)
  finally show ?thesis .
qed


(* Compares the expectation of f over gpv (fail = 0) with the expectation,
   w.r.t. the interface ℐ', of (λ(x, s). f x) over inline callee gpv s.
   Assumptions, for every state s and every x in outs_ℐ ...:
     callee    - callee s x is plossless w.r.t. ℐ';
     callee'   - the results of callee s x lie in responses_ℐ ... x × UNIV;
     WT_callee - callee s x satisfies the "⊢g" judgement w.r.t. ℐ';
   and WT_gpv: gpv satisfies the "⊢g" judgement.
   expectation_gpv2 is a local name for expectation_gpv 0 ℐ', so that the
   fixpoint induction (expectation_gpv_fixp_induct) only affects the
   left-hand side. *)
lemma expectation_gpv_le_inline:
  fixes ℐ'
  defines "expectation_gpv2  expectation_gpv 0 ℐ'"
  assumes callee: "s x. x  outs_ℐ   plossless_gpv ℐ' (callee s x)"
    and callee': "s x. x  outs_ℐ   results_gpv ℐ' (callee s x)  responses_ℐ  x × UNIV"
    and WT_gpv: " ⊢g gpv "
    and WT_callee: "s x. x  outs_ℐ   ℐ' ⊢g callee s x "
  shows "expectation_gpv 0  f gpv  expectation_gpv2 (λ(x, s). f x) (inline callee gpv s)"
  using WT_gpv
proof(induction arbitrary: gpv s rule: expectation_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv')
  (* Key estimate for a query IO out c in the support of the_gpv gpv. *)
  { fix out c
    assume IO: "IO out c  set_spmf (the_gpv gpv)"
    with step.prems have out: "out  outs_ℐ " by(rule WT_gpv_OutD)
    (* The INF is constant in the integration variable, and the first step of
       callee s out is a lossless spmf, so integrating it changes nothing. *)
    have "(INF rresponses_ℐ  out. expectation_gpv' (c r)) = + generat. (INF rresponses_ℐ  out. expectation_gpv' (c r)) measure_spmf (the_gpv (callee s out))"
      using WT_callee[OF out, of s] callee[OF out, of s]
      by(clarsimp simp add: measure_spmf.emeasure_eq_measure plossless_iff_colossless_pfinite colossless_gpv_lossless_spmfD lossless_weight_spmfD)
    (* Pointwise comparison, by cases on the first step of the callee. *)
    also have "  + generat. (case generat of Pure (x, s') 
            + xx. (case xx of Inl (x, _)  f x 
               | Inr (out', callee', rpv)  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r, s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (rpv r) s')) (callee' r'))
            measure_spmf (inline1 callee (c x) s')
         | IO out' rpv  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r'))
       measure_spmf (the_gpv (callee s out))"
    proof(rule nn_integral_mono_AE; simp split!: generat.split)
      (* Case: the callee answers immediately with (x, s'). *)
      fix x s'
      assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
      hence "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
      with callee'[OF out, of s] have x: "x  responses_ℐ  out" by blast
      hence "(INF rresponses_ℐ  out. expectation_gpv' (c r))  expectation_gpv' (c x)" by(rule INF_lower)
      (* Induction hypothesis for the continuation c x. *)
      also have "  expectation_gpv2 (λ(x, s). f x) (inline callee (c x) s')"
        by(rule step.IH)(rule WT_gpv_ContD[OF step.prems(1) IO x] step.prems|assumption)+
      also have " = + xx. (case xx of Inl (x, _)  f x 
               | Inr (out', callee', rpv)  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r, s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (rpv r) s')) (callee' r'))
            measure_spmf (inline1 callee (c x) s')"
        unfolding expectation_gpv2_def
        by(subst expectation_gpv.simps)(auto simp add: inline_sel split_def o_def intro!: nn_integral_cong split: generat.split sum.split)
      finally show "(INF rresponses_ℐ  out. expectation_gpv' (c r))  " .
    next
      (* Case: the callee itself issues a query out' with continuation rpv. *)
      fix out' rpv
      assume IO': "IO out' rpv  set_spmf (the_gpv (callee s out))"
      have "(INF rresponses_ℐ  out. expectation_gpv' (c r))  (INF (r, s')(r'responses_ℐ ℐ' out'. results_gpv ℐ' (rpv r')). expectation_gpv' (c r))"
        using IO' callee'[OF out, of s] by(intro INF_mono)(auto intro: results_gpv.IO)
      (* Rewrite the INF over a union as a nested INF. *)
      also have " = (INF r'responses_ℐ ℐ' out'. INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))"
        by(simp add: INF_UNION)
      also have "  (INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r'))"
      proof(rule INF_mono, rule bexI)
        fix r'
        assume r': "r'  responses_ℐ ℐ' out'"
        (* Induction hypothesis under the inner INF. *)
        have "(INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))  (INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv2 (λ(x, s). f x) (inline callee (c r) s'))"
          using IO IO' step.prems out callee'[OF out, of s] r'
          by(auto intro!: INF_mono rev_bexI step.IH dest: WT_gpv_ContD intro: results_gpv.IO)
        (* Compare the INF over results with the expectation over rpv r',
           which is plossless by plossless_gpv_ContD. *)
        also have "   expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r')"
          unfolding expectation_gpv2_def using plossless_gpv_ContD[OF callee, OF out IO' r'] WT_callee[OF out, of s] IO' r'
          by(intro plossless_INF_le_expectation_gpv)(auto intro: WT_gpv_ContD)
        finally show "(INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))  " .
      qed
      finally show "(INF rresponses_ℐ  out. expectation_gpv' (c r))  " .
    qed
    also note calculation }
  (* Unfold one step of expectation_gpv, inline and inline1 on the right-hand
     side, reduce to a pointwise comparison of integrands, and finish with the
     estimate established in the block above. *)
  then show ?case unfolding expectation_gpv2_def
    apply(rewrite expectation_gpv.simps)
    apply(rewrite inline_sel)
    apply(simp add: o_def pmf_map_spmf_None)
    apply(rewrite sum.case_distrib[where h="case_generat _ _"])
    apply(simp cong del: sum.case_cong_weak)
    apply(simp add: split_beta o_def cong del: sum.case_cong_weak)
    apply(rewrite inline1.simps)
    apply(rewrite measure_spmf_bind)
    apply(rewrite nn_integral_bind[where B="measure_spmf _"])
      apply simp
     apply(simp add: space_subprob_algebra)
    apply(rule nn_integral_mono_AE)
    apply(clarsimp split!: generat.split)
     apply(simp add: measure_spmf_return_spmf nn_integral_return)
    apply(rewrite measure_spmf_bind)
    apply(simp add: nn_integral_bind[where B="measure_spmf _"] space_subprob_algebra)
    apply(subst generat.case_distrib[where h="measure_spmf"])
    apply(subst generat.case_distrib[where h="λx. nn_integral x _"])
    apply(simp add: measure_spmf_return_spmf nn_integral_return split_def)
    done
qed

(* Inlining preserves probabilistic losslessness: if gpv is plossless and
   satisfies the "⊢g" judgement, and the callee is plossless w.r.t. ℐ', maps
   results into the permitted responses, and satisfies the "⊢g" judgement
   w.r.t. ℐ' on all queries in outs_ℐ ..., then inline callee gpv s is
   plossless w.r.t. ℐ'.
   Proved by antisymmetry: one direction from expectation_gpv_const_le
   (using WT_gpv_inline), the other from expectation_gpv_le_inline. *)
lemma plossless_inline:
  assumes lossless: "plossless_gpv  gpv"
    and WT: " ⊢g gpv "
    and callee: "s x. x  outs_ℐ   plossless_gpv ℐ' (callee s x)"
    and callee': "s x. x  outs_ℐ   results_gpv ℐ' (callee s x)  responses_ℐ  x × UNIV"
    and WT_callee: "s x. x  outs_ℐ   ℐ' ⊢g callee s x "
  shows "plossless_gpv ℐ' (inline callee gpv s)"
unfolding pgen_lossless_gpv_def
proof(rule antisym)
  have WT': "ℐ' ⊢g inline callee gpv s " using callee' WT_callee WT by(rule WT_gpv_inline)
  from expectation_gpv_const_le[OF WT', of 0 1]
  show "expectation_gpv 0 ℐ' (λ_. 1) (inline callee gpv s)  1" by(simp add: max_def)

  have "1 = expectation_gpv 0  (λ_. 1) gpv" using lossless by(simp add: pgen_lossless_gpv_def)
  also have "  expectation_gpv 0 ℐ' (λ_. 1) (inline callee gpv s)"
    by(rule expectation_gpv_le_inline[unfolded split_def]; rule callee callee' WT WT_callee)
  finally show "1  " .
qed

(* Running a plossless gpv (satisfying the "⊢g" judgement) against a callee
   that is lossless on outs_ℐ ... and whose results lie in the permitted
   responses gives a lossless spmf.
   Proved by viewing the callee as a GPV via lift_spmf, applying
   plossless_inline with the interface ℐ_full, and rewriting exec_gpv with
   exec_gpv_conv_inline1. *)
lemma plossless_exec_gpv:
  assumes lossless: "plossless_gpv  gpv"
    and WT: " ⊢g gpv "
    and callee: "s x. x  outs_ℐ   lossless_spmf (callee s x)"
    and callee': "s x. x  outs_ℐ   set_spmf (callee s x)  responses_ℐ  x × UNIV"
  shows "lossless_spmf (exec_gpv callee gpv s)"
proof -
  have "plossless_gpv ℐ_full (inline (λs x. lift_spmf (callee s x)) gpv s)"
    using lossless WT by(rule plossless_inline)(simp_all add: callee callee')
  from this[THEN plossless_gpv_lossless_spmfD] show ?thesis
    unfolding exec_gpv_conv_inline1 by(simp add: inline_sel)
qed

(* Monotonicity of expectation_gpv in the interface argument: given the
   assumption `le` relating the two interfaces and that gpv satisfies the
   "⊢g" judgement w.r.t. the first, the expectation w.r.t. the first
   interface is compared with the one w.r.t. ℐ'.
   expectation_gpv' is a local alias for expectation_gpv so that the fixpoint
   induction only unfolds the left-hand occurrence; the step case uses
   responses_ℐ_mono[OF le]. *)
lemma expectation_gpv_ℐ_mono:
  defines "expectation_gpv'  expectation_gpv"
  assumes le: "  ℐ'"
    and WT: " ⊢g gpv "
  shows "expectation_gpv fail  f gpv  expectation_gpv' fail ℐ' f gpv"
  using WT
proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case step [unfolded expectation_gpv'_def]: (step expectation_gpv')
  show ?case unfolding expectation_gpv'_def
    by(subst expectation_gpv.simps)
      (clarsimp intro!: add_mono nn_integral_mono_AE INF_mono split: generat.split
        , auto intro!: bexI step add_mono nn_integral_mono_AE INF_mono split: generat.split dest: WT_gpvD[OF step.prems] intro!: step dest: responses_ℐ_mono[OF le])
qed

(* pgen_lossless_gpv transfers from one interface to ℐ' under the assumption
   `le` relating the interfaces, the "⊢g" judgement for gpv, and the side
   condition `fail` on the fail value.
   Proved by antisymmetry: one direction from expectation_gpv_const_le
   (after WT_gpv_ℐ_mono), the other from expectation_gpv_ℐ_mono and the
   assumption on the original interface. *)
lemma pgen_lossless_gpv_mono:
  assumes *: "pgen_lossless_gpv fail  gpv"
    and le: "  ℐ'"
    and WT: " ⊢g gpv "
    and fail: "fail  1"
  shows "pgen_lossless_gpv fail ℐ' gpv"
  unfolding pgen_lossless_gpv_def
proof(rule antisym)
  from WT le have "ℐ' ⊢g gpv " by(rule WT_gpv_ℐ_mono)
  from expectation_gpv_const_le[OF this, of fail 1] fail
  show "expectation_gpv fail ℐ' (λ_. 1) gpv  1" by(simp add: max_def split: if_split_asm)
  from expectation_gpv_ℐ_mono[OF le WT, of fail "λ_. 1"] *
  show "expectation_gpv fail ℐ' (λ_. 1) gpv  1" by(simp add: pgen_lossless_gpv_def)
qed

(* Instance of pgen_lossless_gpv_mono for plossless_gpv (fail = 0); the side
   condition on fail is discharged by simp. *)
lemma plossless_gpv_mono:
  " plossless_gpv  gpv;   ℐ';  ⊢g gpv    plossless_gpv ℐ' gpv"
  by(erule pgen_lossless_gpv_mono; simp)

(* Instance of pgen_lossless_gpv_mono for pfinite_gpv (fail = 1); the side
   condition on fail is discharged by simp. *)
lemma pfinite_gpv_mono:
  " pfinite_gpv  gpv;   ℐ';  ⊢g gpv    pfinite_gpv ℐ' gpv"
  by(erule pgen_lossless_gpv_mono; simp)

(* Parametricity of pgen_lossless_gpv w.r.t. rel_ℐ C R and rel_gpv'' A C R.
   Obtained by transfer_prover after unfolding the definition, with
   expectation_gpv_parametric' supplied as a transfer rule. *)
lemma pgen_lossless_gpv_parametric': includes lifting_syntax shows
  "((=) ===> rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) pgen_lossless_gpv pgen_lossless_gpv"
  unfolding pgen_lossless_gpv_def supply expectation_gpv_parametric'[transfer_rule] by transfer_prover

(* Parametricity of pgen_lossless_gpv w.r.t. rel_gpv A C: the special case of
   pgen_lossless_gpv_parametric' with R instantiated to equality, rewritten
   with rel_gpv_conv_rel_gpv''. *)
lemma pgen_lossless_gpv_parametric: includes lifting_syntax shows
  "((=) ===> rel_ℐ C (=) ===> rel_gpv A C ===> (=)) pgen_lossless_gpv pgen_lossless_gpv"
  using pgen_lossless_gpv_parametric'[of C "(=)" A] by(simp add: rel_gpv_conv_rel_gpv'')

(* Simp rule: applying map_gpv f id (mapping results by f and leaving the
   other component unchanged via id) does not affect pgen_lossless_gpv.
   Derived from parametricity instantiated with graph relations
   (BNF_Def.Grp) of id and f. *)
lemma pgen_lossless_gpv_map_gpv_id [simp]:
  "pgen_lossless_gpv fail  (map_gpv f id gpv) = pgen_lossless_gpv fail  gpv"
  using pgen_lossless_gpv_parametric[of "BNF_Def.Grp UNIV id" "BNF_Def.Grp UNIV f"]
  unfolding gpv.rel_Grp
  by(auto simp add: eq_alt[symmetric] rel_ℐ_eq rel_fun_def Grp_iff)

(* Invariant-aware versions of expectation_gpv_le_inline and plossless_inline
   inside the locale raw_converter_invariant. Compared with the unconditional
   lemmas of the same names, the callee assumption is only required for
   states satisfying I, and the facts results_callee and WT_callee used in
   the proofs are not assumptions of these lemmas (presumably provided by the
   locale). *)
context raw_converter_invariant begin

(* Compares the expectation of f over gpv (fail = 0) with the expectation
   w.r.t. ℐ' of (λ(x, s). f x) over inline callee gpv s, for a start state s
   satisfying I.
   NOTE(review): a lemma with the same statement appears later in this file
   as expectation_gpv_le_inline_invariant (in raw_converter_invariant). *)
lemma expectation_gpv_le_inline:
  defines "expectation_gpv2  expectation_gpv 0 ℐ'"
  assumes callee: "s x.  x  outs_ℐ ; I s   plossless_gpv ℐ' (callee s x)"
    and WT_gpv: " ⊢g gpv "
    and I: "I s"
  shows "expectation_gpv 0  f gpv  expectation_gpv2 (λ(x, s). f x) (inline callee gpv s)"
  using WT_gpv I
proof(induction arbitrary: gpv s rule: expectation_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv')
  (* Key estimate for a query IO out c in the support of the_gpv gpv. *)
  { fix out c
    assume IO: "IO out c  set_spmf (the_gpv gpv)"
    with step.prems (1) have out: "out  outs_ℐ " by(rule WT_gpv_OutD)
    (* The INF is constant in the integration variable, and the first step of
       callee s out is a lossless spmf. *)
    have "(INF rresponses_ℐ  out. expectation_gpv' (c r)) = + generat. (INF rresponses_ℐ  out. expectation_gpv' (c r)) measure_spmf (the_gpv (callee s out))"
      using WT_callee[OF out, of s] callee[OF out, of s] I s
      by(clarsimp simp add: measure_spmf.emeasure_eq_measure plossless_iff_colossless_pfinite colossless_gpv_lossless_spmfD lossless_weight_spmfD)
    (* Pointwise comparison, by cases on the first step of the callee. *)
    also have "  + generat. (case generat of Pure (x, s') 
            + xx. (case xx of Inl (x, _)  f x 
               | Inr (out', callee', rpv)  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r, s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (rpv r) s')) (callee' r'))
            measure_spmf (inline1 callee (c x) s')
         | IO out' rpv  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r'))
       measure_spmf (the_gpv (callee s out))"
    proof(rule nn_integral_mono_AE; simp split!: generat.split)
      (* Case: the callee answers immediately with (x, s'). *)
      fix x s'
      assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
      hence "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
      (* results_callee gives both a permitted response and the invariant for
         the new state. *)
      with results_callee[OF out, of s] I s have x: "x  responses_ℐ  out" and "I s'" by blast+
      from x have "(INF rresponses_ℐ  out. expectation_gpv' (c r))  expectation_gpv' (c x)" by(rule INF_lower)
      (* Induction hypothesis for the continuation c x. *)
      also have "  expectation_gpv2 (λ(x, s). f x) (inline callee (c x) s')"
        by(rule step.IH)(rule WT_gpv_ContD[OF step.prems(1) IO x] step.prems I s'|assumption)+
      also have " = + xx. (case xx of Inl (x, _)  f x 
               | Inr (out', callee', rpv)  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r, s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (rpv r) s')) (callee' r'))
            measure_spmf (inline1 callee (c x) s')"
        unfolding expectation_gpv2_def
        by(subst expectation_gpv.simps)(auto simp add: inline_sel split_def o_def intro!: nn_integral_cong split: generat.split sum.split)
      finally show "(INF rresponses_ℐ  out. expectation_gpv' (c r))  " .
    next
      (* Case: the callee itself issues a query out' with continuation rpv. *)
      fix out' rpv
      assume IO': "IO out' rpv  set_spmf (the_gpv (callee s out))"
      have "(INF rresponses_ℐ  out. expectation_gpv' (c r))  (INF (r, s')(r'responses_ℐ ℐ' out'. results_gpv ℐ' (rpv r')). expectation_gpv' (c r))"
        using IO' results_callee[OF out, of s] I s by(intro INF_mono)(auto intro: results_gpv.IO)
      (* Rewrite the INF over a union as a nested INF. *)
      also have " = (INF r'responses_ℐ ℐ' out'. INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))"
        by(simp add: INF_UNION)
      also have "  (INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r'))"
      proof(rule INF_mono, rule bexI)
        fix r'
        assume r': "r'  responses_ℐ ℐ' out'"
        (* Induction hypothesis under the inner INF. *)
        have "(INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))  (INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv2 (λ(x, s). f x) (inline callee (c r) s'))"
          using IO IO' step.prems out results_callee[OF out, of s] r'
          by(auto intro!: INF_mono rev_bexI step.IH dest: WT_gpv_ContD intro: results_gpv.IO)
        (* Compare the INF over results with the expectation over rpv r',
           which is plossless by plossless_gpv_ContD. *)
        also have "   expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r')"
          unfolding expectation_gpv2_def using plossless_gpv_ContD[OF callee, OF out I s IO' r'] WT_callee[OF out I s] IO' r'
          by(intro plossless_INF_le_expectation_gpv)(auto intro: WT_gpv_ContD)
        finally show "(INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))  " .
      qed
      finally show "(INF rresponses_ℐ  out. expectation_gpv' (c r))  " .
    qed
    also note calculation }
  (* Unfold one step of expectation_gpv, inline and inline1 on the right-hand
     side, reduce to a pointwise comparison of integrands, and finish with the
     estimate established in the block above. *)
  then show ?case unfolding expectation_gpv2_def
    apply(rewrite expectation_gpv.simps)
    apply(rewrite inline_sel)
    apply(simp add: o_def pmf_map_spmf_None)
    apply(rewrite sum.case_distrib[where h="case_generat _ _"])
    apply(simp cong del: sum.case_cong_weak)
    apply(simp add: split_beta o_def cong del: sum.case_cong_weak)
    apply(rewrite inline1.simps)
    apply(rewrite measure_spmf_bind)
    apply(rewrite nn_integral_bind[where B="measure_spmf _"])
      apply simp
     apply(simp add: space_subprob_algebra)
    apply(rule nn_integral_mono_AE)
    apply(clarsimp split!: generat.split)
     apply(simp add: measure_spmf_return_spmf nn_integral_return)
    apply(rewrite measure_spmf_bind)
    apply(simp add: nn_integral_bind[where B="measure_spmf _"] space_subprob_algebra)
    apply(subst generat.case_distrib[where h="measure_spmf"])
    apply(subst generat.case_distrib[where h="λx. nn_integral x _"])
    apply(simp add: measure_spmf_return_spmf nn_integral_return split_def)
    done
qed

(* Inlining preserves plossless_gpv when the start state satisfies I and the
   callee is plossless w.r.t. ℐ' on states satisfying I.
   Proved by antisymmetry: one direction from expectation_gpv_const_le
   (using WT_gpv_inline_invar), the other from expectation_gpv_le_inline
   above. *)
lemma plossless_inline:
  assumes lossless: "plossless_gpv  gpv"
    and WT: " ⊢g gpv "
    and callee: "s x.  I s; x  outs_ℐ    plossless_gpv ℐ' (callee s x)"
    and I: "I s"
  shows "plossless_gpv ℐ' (inline callee gpv s)"
  unfolding pgen_lossless_gpv_def
proof(rule antisym)
  have WT': "ℐ' ⊢g inline callee gpv s " using WT I by(rule WT_gpv_inline_invar)
  from expectation_gpv_const_le[OF WT', of 0 1]
  show "expectation_gpv 0 ℐ' (λ_. 1) (inline callee gpv s)  1" by(simp add: max_def)

  have "1 = expectation_gpv 0  (λ_. 1) gpv" using lossless by(simp add: pgen_lossless_gpv_def)
  also have "  expectation_gpv 0 ℐ' (λ_. 1) (inline callee gpv s)"
    by(rule expectation_gpv_le_inline[unfolded split_def]; rule callee I WT)
  finally show "1  " .
qed

end

(* Simp rule: the expectation of left_gpv gpv w.r.t. the combined interface
   equals the expectation of gpv w.r.t. the first component interface.
   Proved by parallel fixpoint induction (parallel_fixp_induct_1_1) over the
   two occurrences of expectation_gpv. *)
lemma expectation_left_gpv [simp]:
  "expectation_gpv fail (  ℐ') f (left_gpv gpv) = expectation_gpv fail  f gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions expectation_gpv.mono expectation_gpv.mono expectation_gpv_def expectation_gpv_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv' expectation_gpv'')
  show ?case
    by (auto simp add: pmf_map_spmf_None o_def case_map_generat image_comp
      split: generat.split intro!: nn_integral_cong_AE INF_cong step.IH)
qed

(* Simp rule, mirror image of expectation_left_gpv: expectation_gpv of
   right_gpv gpv over a composite interface with right component ℐ' equals
   expectation_gpv of gpv over ℐ'.
   NOTE(review): the operator forming the composite interface is not legible in
   this copy -- confirm upstream. *)
lemma expectation_right_gpv [simp]:
  "expectation_gpv fail (  ℐ') f (right_gpv gpv) = expectation_gpv fail ℐ' f gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions expectation_gpv.mono expectation_gpv.mono expectation_gpv_def expectation_gpv_def, case_names adm bottom step])
  (* Parallel fixpoint induction over the two occurrences of expectation_gpv. *)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv' expectation_gpv'')
  (* Step case: automation with the induction hypothesis as an intro rule. *)
  show ?case
    by (auto simp add: pmf_map_spmf_None o_def case_map_generat image_comp
      split: generat.split intro!: nn_integral_cong_AE INF_cong step.IH)
qed

(* Simp rule: pgen_lossless_gpv of left_gpv gpv over the composite interface
   reduces to pgen_lossless_gpv of gpv; proved by unfolding
   pgen_lossless_gpv_def and simplifying. *)
lemma pgen_lossless_left_gpv [simp]: "pgen_lossless_gpv fail (  ℐ') (left_gpv gpv) = pgen_lossless_gpv fail  gpv"
  by(simp add: pgen_lossless_gpv_def)

(* Simp rule: pgen_lossless_gpv of right_gpv gpv over the composite interface
   reduces to pgen_lossless_gpv of gpv over ℐ'; proved by unfolding
   pgen_lossless_gpv_def and simplifying. *)
lemma pgen_lossless_right_gpv [simp]: "pgen_lossless_gpv fail (  ℐ') (right_gpv gpv) = pgen_lossless_gpv fail ℐ' gpv"
  by(simp add: pgen_lossless_gpv_def)

(* Relates expectation_gpv 0 of f on gpv to expectation_gpv 0 ℐ' of
   (λ(x, s). f x) on inline callee gpv s, in locale raw_converter_invariant,
   assuming every callee call under the invariant I is plossless, gpv is
   well-typed and I s holds.
   NOTE(review): the relation symbol in the conclusion is not legible in this
   copy; the lemma name and the use of nn_integral_mono_AE / INF_mono suggest
   "less or equal" -- confirm upstream.
   expectation_gpv2 is a local abbreviation for expectation_gpv 0 ℐ'. *)
lemma (in raw_converter_invariant) expectation_gpv_le_inline_invariant:
  defines "expectation_gpv2  expectation_gpv 0 ℐ'"
  assumes callee: "s x.  x  outs_ℐ ; I s   plossless_gpv ℐ' (callee s x)"
    and WT_gpv: " ⊢g gpv "
    and I: "I s"
  shows "expectation_gpv 0  f gpv  expectation_gpv2 (λ(x, s). f x) (inline callee gpv s)"
  using WT_gpv I
proof(induction arbitrary: gpv s rule: expectation_gpv_fixp_induct)
  (* Fixpoint induction on the left-hand expectation_gpv; expectation_gpv'
     denotes the current approximation. *)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step expectation_gpv')
  (* Per-IO-step estimate: for an IO out c in the support of gpv, relate the INF
     of expectation_gpv' over the responses to an integral over the generat
     steps of the callee invoked on out. *)
  { fix out c
    assume IO: "IO out c  set_spmf (the_gpv gpv)"
    with step.prems(1) have out: "out  outs_ℐ " by(rule WT_gpv_OutD)
    (* The INF does not depend on the integration variable; the callee's first
       step is lossless, so integrating the constant changes nothing. *)
    have "(INF rresponses_ℐ  out. expectation_gpv' (c r)) = + generat. (INF rresponses_ℐ  out. expectation_gpv' (c r)) measure_spmf (the_gpv (callee s out))"
      using WT_callee[OF out, of s] callee[OF out, of s] step.prems(2)
      by(clarsimp simp add: measure_spmf.emeasure_eq_measure plossless_iff_colossless_pfinite colossless_gpv_lossless_spmfD lossless_weight_spmfD)
    also have "  + generat. (case generat of Pure (x, s') 
            + xx. (case xx of Inl (x, _)  f x 
               | Inr (out', callee', rpv)  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r, s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (rpv r) s')) (callee' r'))
            measure_spmf (inline1 callee (c x) s')
         | IO out' rpv  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r'))
       measure_spmf (the_gpv (callee s out))"
    proof(rule nn_integral_mono_AE; simp split!: generat.split)
      (* Case Pure: the callee answered immediately with x and new state s'. *)
      fix x s'
      assume Pure: "Pure (x, s')  set_spmf (the_gpv (callee s out))"
      hence "(x, s')  results_gpv ℐ' (callee s out)" by(rule results_gpv.Pure)
      with results_callee[OF out step.prems(2)] have x: "x  responses_ℐ  out" and s': "I s'" by blast+
      from this(1) have "(INF rresponses_ℐ  out. expectation_gpv' (c r))  expectation_gpv' (c x)" by(rule INF_lower)
      also have "  expectation_gpv2 (λ(x, s). f x) (inline callee (c x) s')"
        by(rule step.IH)(rule WT_gpv_ContD[OF step.prems(1) IO x] step.prems s'|assumption)+
      also have " = + xx. (case xx of Inl (x, _)  f x 
               | Inr (out', callee', rpv)  INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r, s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (rpv r) s')) (callee' r'))
            measure_spmf (inline1 callee (c x) s')"
        unfolding expectation_gpv2_def
        by(subst expectation_gpv.simps)(auto simp add: inline_sel split_def o_def intro!: nn_integral_cong split: generat.split sum.split)
      finally show "(INF rresponses_ℐ  out. expectation_gpv' (c r))  " .
    next
      (* Case IO: the callee itself issues a call out' with continuation rpv. *)
      fix out' rpv
      assume IO': "IO out' rpv  set_spmf (the_gpv (callee s out))"
      have "(INF rresponses_ℐ  out. expectation_gpv' (c r))  (INF (r, s')(r'responses_ℐ ℐ' out'. results_gpv ℐ' (rpv r')). expectation_gpv' (c r))"
        using IO' results_callee[OF out step.prems(2)] by(intro INF_mono)(auto intro: results_gpv.IO)
      also have " = (INF r'responses_ℐ ℐ' out'. INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))"
        by(simp add: INF_UNION)
      also have "  (INF r'responses_ℐ ℐ' out'. expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r'))"
      proof(rule INF_mono, rule bexI)
        fix r'
        assume r': "r'  responses_ℐ ℐ' out'"
        (* Apply the induction hypothesis pointwise under the INF ... *)
        have "(INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))  (INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv2 (λ(x, s). f x) (inline callee (c r) s'))"
          using IO IO' step.prems out results_callee[OF out, of s] r'
          by(auto intro!: INF_mono rev_bexI step.IH dest: WT_gpv_ContD intro: results_gpv.IO)
        (* ... then pass from the INF over results to the expectation of the
           plossless continuation via plossless_INF_le_expectation_gpv. *)
        also have "   expectation_gpv 0 ℐ' (λ(r', s'). expectation_gpv 0 ℐ' (λ(x, s). f x) (inline callee (c r') s')) (rpv r')"
          unfolding expectation_gpv2_def using plossless_gpv_ContD[OF callee, OF out step.prems(2) IO' r'] WT_callee[OF out step.prems(2)] IO' r'
          by(intro plossless_INF_le_expectation_gpv)(auto intro: WT_gpv_ContD)
        finally show "(INF (r, s')results_gpv ℐ' (rpv r'). expectation_gpv' (c r))  " .
      qed
      finally show "(INF rresponses_ℐ  out. expectation_gpv' (c r))  " .
    qed
    also note calculation }
  (* Unfold one step of expectation_gpv, inline and inline1 on the right-hand
     side, push the integral through the bind, and discharge the IO case with
     the estimate established in the block above. The order of the rewrite steps
     matters. *)
  then show ?case unfolding expectation_gpv2_def
    apply(rewrite expectation_gpv.simps)
    apply(rewrite inline_sel)
    apply(simp add: o_def pmf_map_spmf_None)
    apply(rewrite sum.case_distrib[where h="case_generat _ _"])
    apply(simp cong del: sum.case_cong_weak)
    apply(simp add: split_beta o_def cong del: sum.case_cong_weak)
    apply(rewrite inline1.simps)
    apply(rewrite measure_spmf_bind)
    apply(rewrite nn_integral_bind[where B="measure_spmf _"])
      apply simp
     apply(simp add: space_subprob_algebra)
    apply(rule nn_integral_mono_AE)
    apply(clarsimp split!: generat.split)
     apply(simp add: measure_spmf_return_spmf nn_integral_return)
    apply(rewrite measure_spmf_bind)
    apply(simp add: nn_integral_bind[where B="measure_spmf _"] space_subprob_algebra)
    apply(subst generat.case_distrib[where h="measure_spmf"])
    apply(subst generat.case_distrib[where h="λx. nn_integral x _"])
    apply(simp add: measure_spmf_return_spmf nn_integral_return split_def)
    done
qed

(* Inlining preserves probabilistic losslessness in locale
   raw_converter_invariant: from plossless_gpv of gpv, well-typedness of gpv,
   plossless_gpv ℐ' of every callee call under the invariant I, and I s,
   conclude plossless_gpv ℐ' (inline callee gpv s).
   NOTE(review): the second half of the proof cites expectation_gpv_le_inline
   rather than expectation_gpv_le_inline_invariant; both names exist in this
   file, so this may be intentional -- confirm. *)
lemma (in raw_converter_invariant) plossless_inline_invariant:
  assumes lossless: "plossless_gpv  gpv"
    and WT: " ⊢g gpv "
    and callee: "s x.  x  outs_ℐ ; I s   plossless_gpv ℐ' (callee s x)"
    and I: "I s"
  shows "plossless_gpv ℐ' (inline callee gpv s)"
  unfolding pgen_lossless_gpv_def
proof(rule antisym)
  (* First direction: the inlined gpv is well-typed (WT_gpv_inline_invar), so
     expectation_gpv_const_le bounds the expectation of the constant 1. *)
  have WT': "ℐ' ⊢g inline callee gpv s " using WT I by(rule WT_gpv_inline_invar)
  from expectation_gpv_const_le[OF WT', of 0 1]
  show "expectation_gpv 0 ℐ' (λ_. 1) (inline callee gpv s)  1" by(simp add: max_def)

  (* Second direction: rewrite 1 with the losslessness hypothesis and chain
     with the inlining estimate (split_def unfolded). *)
  have "1 = expectation_gpv 0  (λ_. 1) gpv" using lossless by(simp add: pgen_lossless_gpv_def)
  also have "  expectation_gpv 0 ℐ' (λ_. 1) (inline callee gpv s)"
    by(rule expectation_gpv_le_inline[unfolded split_def]; rule callee WT WT_callee I)
  finally show "1  " .
qed

context callee_invariant_on begin

(* Lifting the spmf-valued callee to a gpv-valued one with lift_spmf yields an
   instance of raw_converter_invariant, for an arbitrary target interface ℐ'.
   This lets the results about inline be reused for exec_gpv. *)
lemma raw_converter_invariant: "raw_converter_invariant  ℐ' (λs x. lift_spmf (callee s x)) I"
  by(unfold_locales)(auto dest: callee_invariant WT_callee WT_calleeD)

(* Executing a plossless, well-typed gpv with a callee that is lossless on every
   state satisfying I, started in a state satisfying I, gives a lossless spmf.
   The lemma is stated directly in the enclosing context; repeating the locale
   target here would be redundant. *)
lemma plossless_exec_gpv:
  assumes lossless: "plossless_gpv  gpv"
    and WT: " ⊢g gpv "
    and callee: "s x.  x  outs_ℐ ; I s   lossless_spmf (callee s x)"
    and I: "I s"
  shows "lossless_spmf (exec_gpv callee gpv s)"
proof -
  (* Interpret the lifted callee as a raw converter, generically in ℐ'. *)
  interpret raw_converter_invariant  ℐ' "λs x. lift_spmf (callee s x)" I for ℐ'
    by(rule raw_converter_invariant)
  (* Inlining the lifted callee against ℐ_full is plossless ... *)
  have "plossless_gpv ℐ_full (inline (λs x. lift_spmf (callee s x)) gpv s)"
    using lossless WT by(rule plossless_inline)(simp_all add: callee I)
  (* ... and exec_gpv is expressed through inline1 (exec_gpv_conv_inline1), so
     losslessness of the first step transfers to exec_gpv. *)
  from this[THEN plossless_gpv_lossless_spmfD] show ?thesis
    unfolding exec_gpv_conv_inline1 by(simp add: inline_sel)
qed

end

(* Relates the expectation of the constant y on gpv (over ℐ') to the expectation
   of the constant y on mk_lossless_gpv (responses_ℐ ℐ') x gpv over another
   interface with the same outs_ℐ.
   rhs is a local abbreviation so that only the left-hand expectation_gpv takes
   part in the fixpoint induction.
   NOTE(review): the relation symbol in the conclusion is not legible in this
   copy; the use of nn_integral_mono_AE / INF_mono suggests "less or equal" --
   confirm upstream. *)
lemma expectation_gpv_mk_lossless_gpv:
  fixes  y
  defines "rhs  expectation_gpv 0  (λ_. y)"
  assumes WT: "ℐ' ⊢g gpv "
    and outs: "outs_ℐ  = outs_ℐ ℐ'"
  shows "expectation_gpv 0 ℐ' (λ_. y) gpv  rhs (mk_lossless_gpv (responses_ℐ ℐ') x gpv)"
  using WT
proof(induction arbitrary: gpv rule: expectation_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case step [unfolded rhs_def]: (step expectation_gpv')
  (* Unfold the right-hand expectation once and compare pointwise; the two
     subgoals handled below arise from the if_split in the clarsimp call. *)
  show ?case using step.prems outs unfolding rhs_def
    apply(subst expectation_gpv.simps)
    apply(clarsimp intro!: nn_integral_mono_AE INF_mono split!: generat.split if_split)
    subgoal
      by(frule (1) WT_gpv_OutD)(auto simp add: in_outs_ℐ_iff_responses_ℐ intro!: bexI step.IH[unfolded rhs_def] dest: WT_gpv_ContD)
    apply(frule (1) WT_gpv_OutD; clarsimp simp add: in_outs_ℐ_iff_responses_ℐ ex_in_conv[symmetric])
    subgoal for out c input input'
      using step.hyps[of "c input'"] expectation_gpv_const_le[of ℐ' "c input'" 0 y]
      by- (drule (2) WT_gpv_ContD, fastforce intro: rev_bexI simp add: max_def)
    done
qed

(* mk_lossless_gpv preserves probabilistic losslessness when moving to an
   interface ℐ' with the same outs_ℐ: a plossless, well-typed gpv stays
   plossless (over ℐ') after mk_lossless_gpv.
   Proof: antisymmetry after unfolding pgen_lossless_gpv_def; one direction from
   expectation_gpv_const_le, the other from expectation_gpv_mk_lossless_gpv
   instantiated with y = 1. *)
lemma plossless_gpv_mk_lossless_gpv:
  assumes "plossless_gpv  gpv"
    and " ⊢g gpv "
    and "outs_ℐ  = outs_ℐ ℐ'"
  shows "plossless_gpv ℐ' (mk_lossless_gpv (responses_ℐ ) x gpv)"
  using assms expectation_gpv_mk_lossless_gpv[OF assms(2), of ℐ' 1 x]
  unfolding pgen_lossless_gpv_def
  by -(rule antisym[OF expectation_gpv_const_le[THEN order_trans]]; simp add: WT_gpv_mk_lossless_gpv)

(* In locale callee_invariant_on: for a well-typed gpv and a start state
   satisfying I, running mk_lossless_gpv (responses_ℐ ...) x gpv with the callee
   gives the same spmf as running gpv itself. *)
lemma (in callee_invariant_on) exec_gpv_mk_lossless_gpv:
  assumes " ⊢g gpv "
    and "I s"
  shows "exec_gpv callee (mk_lossless_gpv (responses_ℐ ) x gpv) s = exec_gpv callee gpv s"
  using assms
proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step exec_gpv')
  (* Step case: compare both sides under the outer bind; the induction
     hypothesis applies to continuations, using that the callee's answers are
     valid responses and that the invariant is preserved. *)
  show ?case using step.prems WT_gpv_OutD[OF step.prems(1)]
    by(clarsimp simp add: bind_map_spmf intro!: bind_spmf_cong[OF refl] split!: generat.split if_split)
      (force intro!: step.IH dest: WT_callee[THEN WT_calleeD] WT_gpv_OutD callee_invariant WT_gpv_ContD)+
qed


(* Simp rule: expectation_gpv commutes with map_gpv' -- the result map g moves
   into the integrand and the call/response maps h, k move into the interface
   via map_ℐ h k.
   NOTE(review): the operator between f and g on the right-hand side is not
   legible in this copy; presumably function composition -- confirm upstream. *)
lemma expectation_gpv_map_gpv' [simp]:
  "expectation_gpv fail  f (map_gpv' g h k gpv) =
   expectation_gpv fail (map_ℐ h k ) (f  g) gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions expectation_gpv.mono expectation_gpv.mono expectation_gpv_def expectation_gpv_def, case_names adm bottom step])
  (* Parallel fixpoint induction over the two occurrences of expectation_gpv. *)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step exp1 exp2)
  (* The failure mass of the first step is unchanged by map_gpv'. *)
  have "pmf (the_gpv (map_gpv' g h k gpv)) None = pmf (the_gpv gpv) None"
    by(simp add: pmf_map_spmf_None)
  then show ?case 
    by simp
      (auto simp add: nn_integral_measure_spmf step.IH image_comp
        split: generat.split intro!: nn_integral_cong)
qed

(* Simp rule: pgen_lossless_gpv of map_gpv' f g h gpv is expressed as
   pgen_lossless_gpv of gpv over the mapped interface map_ℐ g h; proved by
   unfolding pgen_lossless_gpv_def and simplifying with o_def. *)
lemma plossless_gpv_map_gpv' [simp]:
  "pgen_lossless_gpv b  (map_gpv' f g h gpv)  pgen_lossless_gpv b (map_ℐ g h ) gpv"
  unfolding pgen_lossless_gpv_def by(simp add: o_def)

end

Theory GPV_Bisim

(* Title: GPV_Bisim.thy
  Author: Andreas Lochbihler, ETH Zurich *)

theory GPV_Bisim imports
  GPV_Expectation
begin

subsection ‹Bisimulation for oracles›

text ‹Bisimulation is a consequence of parametricity›

(* Oracle bisimulation, derived from parametricity of exec_gpv: if the start
   states are related by X and related states make the two oracles produce
   results related by "equal answers and X on the successor states", then the
   two executions of gpv are related in the same way.
   The proof instantiates exec_gpv_parametric with X and equality. *)
lemma exec_gpv_oracle_bisim':
  assumes *: "X s1 s2"
  and bisim: "s1 s2 x. X s1 s2  rel_spmf (λ(a, s1') (b, s2'). a = b  X s1' s2') (oracle1 s1 x) (oracle2 s2 x)"
  shows "rel_spmf (λ(a, s1') (b, s2'). a = b  X s1' s2') (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
by(rule exec_gpv_parametric[of X "(=)" "(=)", unfolded gpv.rel_eq rel_prod_conv, THEN rel_funD, THEN rel_funD, THEN rel_funD, OF rel_funI refl, OF rel_funI *])(simp add: bisim)

(* Variant of exec_gpv_oracle_bisim' with a user-chosen result relation R: it
   suffices that R holds for all outcome pairs with the same answer whose final
   states are related by X and which lie in the supports of the two executions.
   Proved by weakening the relation from exec_gpv_oracle_bisim' with
   spmf_rel_mono_strong. *)
lemma exec_gpv_oracle_bisim:
  assumes *: "X s1 s2"
  and bisim: "s1 s2 x. X s1 s2  rel_spmf (λ(a, s1') (b, s2'). a = b  X s1' s2') (oracle1 s1 x) (oracle2 s2 x)"
  and R: "x s1' s2'.  X s1' s2'; (x, s1')  set_spmf (exec_gpv oracle1 gpv s1); (x, s2')  set_spmf (exec_gpv oracle2 gpv s2)   R (x, s1') (x, s2')"
  shows "rel_spmf R (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
apply(rule spmf_rel_mono_strong)
apply(rule exec_gpv_oracle_bisim'[OF * bisim])
apply(auto dest: R)
done

(* Bisimilar oracles are indistinguishable through run_gpv: under the same
   hypotheses as exec_gpv_oracle_bisim', the two run_gpv results are equal.
   The equality is first folded into rel_spmf (=) via spmf_rel_eq. *)
lemma run_gpv_oracle_bisim:
  assumes  "X s1 s2"
  and "s1 s2 x. X s1 s2  rel_spmf (λ(a, s1') (b, s2'). a = b  X s1' s2') (oracle1 s1 x) (oracle2 s2 x)"
  shows "run_gpv oracle1 gpv s1 = run_gpv oracle2 gpv s2"
using exec_gpv_oracle_bisim'[OF assms]
by(fold spmf_rel_eq)(fastforce simp add: spmf_rel_map intro: rel_spmf_mono)

(* Local context fixing a joint oracle on pairs of states together with the two
   individual oracles and their "bad" predicates. *)
context
  fixes joint_oracle :: "('s1 × 's2)  'a  (('b × 's1) × ('b × 's2)) spmf"
  and oracle1 :: "'s1  'a  ('b × 's1) spmf"
  and bad1 :: "'s1  bool"
  and oracle2 :: "'s2  'a  ('b × 's2) spmf"
  and bad2 :: "'s2  bool"
begin

(* exec_until_bad runs gpv against the joint oracle on a pair of states and
   returns a pair of (result, state) outcomes.
   - If the bad-flag condition on (s1, s2) holds, the two sides are run
     independently: pair_spmf of the two exec_gpv executions.
   - Otherwise one step of gpv is taken: Pure x returns x on both sides; IO out f
     queries joint_oracle and then either switches to independent execution (if
     the bad-flag condition holds for the new states) or recurses.
   NOTE(review): the connective between bad1 and bad2 in the two conditions is
   not legible in this copy; presumably a disjunction -- confirm upstream. *)
partial_function (spmf) exec_until_bad :: "('x, 'a, 'b) gpv  's1  's2  (('x × 's1) × ('x × 's2)) spmf"
where
  "exec_until_bad gpv s1 s2 = 
  (if bad1 s1  bad2 s2 then pair_spmf (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)
  else bind_spmf (the_gpv gpv) (λgenerat.
     case generat of Pure x  return_spmf ((x, s1), (x, s2))
     | IO out f  bind_spmf (joint_oracle (s1, s2) out) (λ((x, s1'), (y, s2')). 
       if bad1 s1'  bad2 s2' then pair_spmf (exec_gpv oracle1 (f x) s1') (exec_gpv oracle2 (f y) s2')
       else exec_until_bad (f x) s1' s2')))"

(* Curried fixpoint induction rule for exec_until_bad with named cases
   adm / bottom / step, obtained from exec_until_bad.fixp_induct by unfolding
   curry_conv. *)
lemma exec_until_bad_fixp_induct [case_names adm bottom step]:
  assumes "ccpo.admissible (fun_lub lub_spmf) (fun_ord (ord_spmf (=))) (λf. P (λgpv s1 s2. f ((gpv, s1), s2)))"
  and "P (λ_ _ _. return_pmf None)"
  and "exec_until_bad'. P exec_until_bad'  
     P (λgpv s1 s2. if bad1 s1  bad2 s2 then pair_spmf (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)
     else bind_spmf (the_gpv gpv) (λgenerat.
     case generat of Pure x  return_spmf ((x, s1), (x, s2))
     | IO out f  bind_spmf (joint_oracle (s1, s2) out) (λ((x, s1'), (y, s2')). 
       if bad1 s1'  bad2 s2' then pair_spmf (exec_gpv oracle1 (f x) s1') (exec_gpv oracle2 (f y) s2') 
       else exec_until_bad' (f x) s1' s2')))"
  shows "P exec_until_bad"
using assms by(rule exec_until_bad.fixp_induct[unfolded curry_conv[abs_def]])

end

(* "Identical until bad" bisimulation for a plossless gpv.
   Hypotheses visible in the statement:
   - * / bad: the start states have equal bad flags and are related by X_bad if
     bad, by X otherwise;
   - bisim: from X-related states, on outputs of the interface, the oracles
     produce results with equal bad flags that are X_bad-related if bad, and
     otherwise have equal answers and X-related states;
   - bad_sticky1 / bad_sticky2: once one side is bad, "bad and X_bad" is a callee
     invariant of the other oracle;
   - lossless1 / lossless2: the oracles are lossless in bad states;
   - lossless, WT_gpv: gpv is plossless and well-typed;
   - WT_oracle1 / WT_oracle2: both oracles are well-typed in every state.
   Conclusion: the two exec_gpv executions are related by the same relation as in
   bisim.
   NOTE(review): several connectives are not legible in this copy; confirm the
   exact statement against the upstream theory. *)
lemma exec_gpv_oracle_bisim_bad_plossless:
  fixes s1 :: 's1 and s2 :: 's2 and X :: "'s1  's2  bool"
  and oracle1 :: "'s1  'a  ('b × 's1) spmf"
  and oracle2 :: "'s2  'a  ('b × 's2) spmf"
  assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
  and bad: "bad1 s1 = bad2 s2"
  and bisim: "s1 s2 x.  X s1 s2; x  outs_ℐ    rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
  and bad_sticky1: "s2. bad2 s2  callee_invariant_on oracle1 (λs1. bad1 s1  X_bad s1 s2) "
  and bad_sticky2: "s1. bad1 s1  callee_invariant_on oracle2 (λs2. bad2 s2  X_bad s1 s2) "
  and lossless1: "s1 x.  bad1 s1; x  outs_ℐ    lossless_spmf (oracle1 s1 x)"
  and lossless2: "s2 x.  bad2 s2; x  outs_ℐ    lossless_spmf (oracle2 s2 x)"
  and lossless: "plossless_gpv  gpv"
  and WT_oracle1: "s1.  ⊢c oracle1 s1 " (* stronger than the invariants above because unconditional *)
  and WT_oracle2: "s2.  ⊢c oracle2 s2 "
  and WT_gpv: " ⊢g gpv "
  shows "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
  (is "rel_spmf ?R ?p ?q")
proof -
  let ?R' = "λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')"
  (* Turn the pointwise rel_spmf hypothesis into one joint oracle: a coupling
     whose projections are oracle1 and oracle2 (facts oracle1, oracle2, stored
     in symmetric orientation) and whose support satisfies the relation
     (fact 3). *)
  from bisim have "s1 s2. x  outs_ℐ . X s1 s2  rel_spmf ?R' (oracle1 s1 x) (oracle2 s2 x)" by blast
  then obtain joint_oracle
    where oracle1 [symmetric]: "s1 s2 x.  X s1 s2; x  outs_ℐ    map_spmf fst (joint_oracle s1 s2 x) = oracle1 s1 x"
    and oracle2 [symmetric]: "s1 s2 x.  X s1 s2; x  outs_ℐ    map_spmf snd (joint_oracle s1 s2 x) = oracle2 s2 x"
    and 3 [rotated 2]: "s1 s2 x y y' s1' s2'.  X s1 s2; x  outs_ℐ ; ((y, s1'), (y', s2'))  set_spmf (joint_oracle s1 s2 x) 
       bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else y = y'  X s1' s2')"
    apply atomize_elim
    apply(unfold rel_spmf_simps all_conj_distrib[symmetric] all_simps(6) imp_conjR[symmetric])
    apply(subst choice_iff[symmetric] ex_simps(6))+
    apply fastforce
    done
  (* ?pq is the coupled execution that witnesses the final rel_spmf. *)
  let ?joint_oracle = "λ(s1, s2). joint_oracle s1 s2"
  let ?pq = "exec_until_bad ?joint_oracle oracle1 bad1 oracle2 bad2 gpv s1 s2"

  (* Elements of the joint oracle's support project into the supports of the two
     individual oracles. *)
  have setD: "s1 s2 x y y' s1' s2'.  X s1 s2; x  outs_ℐ ; ((y, s1'), (y', s2'))  set_spmf (joint_oracle s1 s2 x) 
     (y, s1')  set_spmf (oracle1 s1 x)  (y', s2')  set_spmf (oracle2 s2 x)"
    unfolding oracle1 oracle2 by(auto intro: rev_image_eqI)
  (* Three goals follow: the two marginals of ?pq and the relation on its
     support. *)
  show ?thesis
  proof
    (* First marginal, by antisymmetry of the spmf ordering. *)
    show "map_spmf fst ?pq = exec_gpv oracle1 gpv s1"
    proof(rule spmf.leq_antisym)
      (* Direction 1: fixpoint induction on exec_until_bad. *)
      show "ord_spmf (=) (map_spmf fst ?pq) (exec_gpv oracle1 gpv s1)" using * bad WT_gpv lossless
      proof(induction arbitrary: s1 s2 gpv rule: exec_until_bad_fixp_induct)
        case adm show ?case by simp
        case bottom show ?case by simp
        case (step exec_until_bad')
        show ?case
        proof(cases "bad2 s2")
          case True
          (* Already bad: the other side's execution has weight 1, so projecting
             the independent pair loses nothing. *)
          then have "weight_spmf (exec_gpv oracle2 gpv s2) = 1"
            using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1 gpv s2]
              step.prems weight_spmf_le_1[of "exec_gpv oracle2 gpv s2"]
            by(simp add: pgen_lossless_gpv_def weight_gpv_def)
          then show ?thesis using True by simp
        next
          case False
          hence "¬ bad1 s1" using step.prems(2) by simp
          (* Facts about one joint oracle step, used by the final fastforce. *)
          moreover {
            fix out c r1 s1' r2 s2'
            assume IO: "IO out c  set_spmf (the_gpv gpv)"
              and joint: "((r1, s1'), (r2, s2'))  set_spmf (joint_oracle s1 s2 out)"
            from step.prems(3) IO have out: "out  outs_ℐ " by(rule WT_gpvD)
            from setD[OF _ out joint] step.prems(1) False
            have 1: "(r1, s1')  set_spmf (oracle1 s1 out)"
              and 2: "(r2, s2')  set_spmf (oracle2 s2 out)" by simp_all
            hence r1: "r1  responses_ℐ  out" and r2: "r2  responses_ℐ  out"
              using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
            have *: "plossless_gpv  (c r2)" using step.prems(4) IO r2 step.prems(3)
              by(rule plossless_gpv_ContD)
            then have "bad2 s2'  weight_spmf (exec_gpv oracle2 (c r2) s2') = 1"
              and "¬ bad2 s2'  ord_spmf (=) (map_spmf fst (exec_until_bad' (c r2) s1' s2')) (exec_gpv oracle1 (c r2) s1')"
              using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1' "c r2" s2'] 
                weight_spmf_le_1[of "exec_gpv oracle2 (c r2) s2'"] WT_gpv_ContD[OF step.prems(3) IO r2]
                3[OF joint _ out] step.prems(1) False
              by(simp_all add: pgen_lossless_gpv_def weight_gpv_def step.IH) }
          ultimately show ?thesis using False step.prems(1)
            by(rewrite in "ord_spmf _ _ " exec_gpv.simps)
              (fastforce simp add: split_def bind_map_spmf map_spmf_bind_spmf oracle1 WT_gpv_OutD[OF step.prems(3)] intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
        qed
      qed
      (* Direction 2: strong fixpoint induction on exec_gpv. *)
      show "ord_spmf (=) (exec_gpv oracle1 gpv s1) (map_spmf fst ?pq)" using * bad WT_gpv lossless
      proof(induction arbitrary: gpv s1 s2 rule: exec_gpv_fixp_induct_strong)
        case adm show ?case by simp
        case bottom show ?case by simp
        case (step exec_gpv')
        then show ?case
        proof(cases "bad2 s2")
          case True
          then have "weight_spmf (exec_gpv oracle2 gpv s2) = 1"
            using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1 gpv s2]
              step.prems weight_spmf_le_1[of "exec_gpv oracle2 gpv s2"]
            by(simp add: pgen_lossless_gpv_def weight_gpv_def)
          then show ?thesis using True
            by(rewrite exec_until_bad.simps; rewrite exec_gpv.simps)
              (clarsimp intro!: ord_spmf_bind_reflI split!: generat.split simp add: step.hyps)
        next
          case False
          hence "¬ bad1 s1" using step.prems(2) by simp
          moreover {
            fix out c r1 s1' r2 s2'
            assume IO: "IO out c  set_spmf (the_gpv gpv)"
              and joint: "((r1, s1'), (r2, s2'))  set_spmf (joint_oracle s1 s2 out)"
            from step.prems(3) IO have out: "out  outs_ℐ " by(rule WT_gpvD)
            from setD[OF _ out joint] step.prems(1) False
            have 1: "(r1, s1')  set_spmf (oracle1 s1 out)"
              and 2: "(r2, s2')  set_spmf (oracle2 s2 out)" by simp_all
            hence r1: "r1  responses_ℐ  out" and r2: "r2  responses_ℐ  out"
              using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
            have *: "plossless_gpv  (c r2)" using step.prems(4) IO r2 step.prems(3)
              by(rule plossless_gpv_ContD)
            then have "bad2 s2'  weight_spmf (exec_gpv oracle2 (c r2) s2') = 1" 
              and "¬ bad2 s2'  ord_spmf (=) (exec_gpv' (c r2) s1') (map_spmf fst (exec_until_bad (λ(x, y). joint_oracle x y) oracle1 bad1 oracle2 bad2 (c r2) s1' s2'))"
              using callee_invariant_on.weight_exec_gpv[OF bad_sticky2 lossless2, of s1' "c r2" s2'] 
                weight_spmf_le_1[of "exec_gpv oracle2 (c r2) s2'"] WT_gpv_ContD[OF step.prems(3) IO r2]
                3[OF joint _ out] step.prems(1) False
              by(simp_all add: pgen_lossless_gpv_def weight_gpv_def step.IH) }
          ultimately show ?thesis using False step.prems(1)
            by(rewrite exec_until_bad.simps)
              (fastforce simp add: map_spmf_bind_spmf WT_gpv_OutD[OF step.prems(3)] oracle1 bind_map_spmf step.hyps intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
        qed
      qed
    qed

    (* Second marginal: symmetric to the first, with the roles of the two
       oracles exchanged (bad_sticky1 / lossless1 now give weight 1). *)
    show "map_spmf snd ?pq = exec_gpv oracle2 gpv s2"
    proof(rule spmf.leq_antisym)
      show "ord_spmf (=) (map_spmf snd ?pq) (exec_gpv oracle2 gpv s2)" using * bad WT_gpv lossless
      proof(induction arbitrary: s1 s2 gpv rule: exec_until_bad_fixp_induct)
        case adm show ?case by simp
        case bottom show ?case by simp
        case (step exec_until_bad')
        show ?case
        proof(cases "bad2 s2")
          case True
          then have "weight_spmf (exec_gpv oracle1 gpv s1) = 1"
            using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2 gpv s1]
              step.prems weight_spmf_le_1[of "exec_gpv oracle1 gpv s1"]
            by(simp add: pgen_lossless_gpv_def weight_gpv_def)
          then show ?thesis using True by simp
        next
          case False
          hence "¬ bad1 s1" using step.prems(2) by simp
          moreover {
            fix out c r1 s1' r2 s2'
            assume IO: "IO out c  set_spmf (the_gpv gpv)"
              and joint: "((r1, s1'), (r2, s2'))  set_spmf (joint_oracle s1 s2 out)"
            from step.prems(3) IO have out: "out  outs_ℐ " by(rule WT_gpvD)
            from setD[OF _ out joint] step.prems(1) False
            have 1: "(r1, s1')  set_spmf (oracle1 s1 out)"
              and 2: "(r2, s2')  set_spmf (oracle2 s2 out)" by simp_all
            hence r1: "r1  responses_ℐ  out" and r2: "r2  responses_ℐ  out"
              using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
            have *: "plossless_gpv  (c r1)" using step.prems(4) IO r1 step.prems(3)
              by(rule plossless_gpv_ContD)
            then have "bad2 s2'  weight_spmf (exec_gpv oracle1 (c r1) s1') = 1"
              and "¬ bad2 s2'  ord_spmf (=) (map_spmf snd (exec_until_bad' (c r2) s1' s2')) (exec_gpv oracle2 (c r2) s2')"
              using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2' "c r1" s1'] 
                weight_spmf_le_1[of "exec_gpv oracle1 (c r1) s1'"] WT_gpv_ContD[OF step.prems(3) IO r1]
                3[OF joint _ out] step.prems(1) False
              by(simp_all add: pgen_lossless_gpv_def weight_gpv_def step.IH) }
          ultimately show ?thesis using False step.prems(1)
            by(rewrite in "ord_spmf _ _ " exec_gpv.simps)
              (fastforce simp add: split_def bind_map_spmf map_spmf_bind_spmf oracle2 WT_gpv_OutD[OF step.prems(3)] intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
        qed
      qed
      show "ord_spmf (=) (exec_gpv oracle2 gpv s2) (map_spmf snd ?pq)" using * bad WT_gpv lossless
      proof(induction arbitrary: gpv s1 s2 rule: exec_gpv_fixp_induct_strong)
        case adm show ?case by simp
        case bottom show ?case by simp
        case (step exec_gpv')
        then show ?case
        proof(cases "bad2 s2")
          case True
          then have "weight_spmf (exec_gpv oracle1 gpv s1) = 1"
            using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2 gpv s1]
              step.prems weight_spmf_le_1[of "exec_gpv oracle1 gpv s1"]
            by(simp add: pgen_lossless_gpv_def weight_gpv_def)
          then show ?thesis using True
            by(rewrite exec_until_bad.simps; subst (2) exec_gpv.simps)
              (clarsimp intro!: ord_spmf_bind_reflI split!: generat.split simp add: step.hyps)
        next
          case False
          hence "¬ bad1 s1" using step.prems(2) by simp
          moreover {
            fix out c r1 s1' r2 s2'
            assume IO: "IO out c  set_spmf (the_gpv gpv)"
              and joint: "((r1, s1'), (r2, s2'))  set_spmf (joint_oracle s1 s2 out)"
            from step.prems(3) IO have out: "out  outs_ℐ " by(rule WT_gpvD)
            from setD[OF _ out joint] step.prems(1) False
            have 1: "(r1, s1')  set_spmf (oracle1 s1 out)"
              and 2: "(r2, s2')  set_spmf (oracle2 s2 out)" by simp_all
            hence r1: "r1  responses_ℐ  out" and r2: "r2  responses_ℐ  out"
              using WT_oracle1 WT_oracle2 out by(blast dest: WT_calleeD)+
            have *: "plossless_gpv  (c r1)" using step.prems(4) IO r1 step.prems(3)
              by(rule plossless_gpv_ContD)
            then have "bad2 s2'  weight_spmf (exec_gpv oracle1 (c r1) s1') = 1" 
              and "¬ bad2 s2'  ord_spmf (=) (exec_gpv' (c r2) s2') (map_spmf snd (exec_until_bad (λ(x, y). joint_oracle x y) oracle1 bad1 oracle2 bad2 (c r2) s1' s2'))"
              using callee_invariant_on.weight_exec_gpv[OF bad_sticky1 lossless1, of s2' "c r1" s1'] 
                weight_spmf_le_1[of "exec_gpv oracle1 (c r1) s1'"] WT_gpv_ContD[OF step.prems(3) IO r1]
                3[OF joint _ out] step.prems(1) False
              by(simp_all add: pgen_lossless_gpv_def step.IH weight_gpv_def) }
          ultimately show ?thesis using False step.prems(1)
            by(rewrite exec_until_bad.simps)
              (fastforce simp add: map_spmf_bind_spmf WT_gpv_OutD[OF step.prems(3)] oracle2 bind_map_spmf step.hyps intro!: ord_spmf_bind_reflI split!: generat.split dest: 3)
        qed
      qed
    qed

    (* Support: every outcome pair of ?pq satisfies the target relation; proved
       by fixpoint induction on exec_until_bad. *)
    have "set_spmf ?pq  {(as1, bs2). ?R' as1 bs2}" using * bad WT_gpv
    proof(induction arbitrary: gpv s1 s2 rule: exec_until_bad_fixp_induct)
      case adm show ?case by(intro cont_intro ccpo_class.admissible_leI)
      case bottom show ?case by simp
      case step
      (* switch: the step at which the joint oracle call makes the states bad
         (not bad before, bad2 s2' after). From then on both sides run
         independently, and the sticky invariants carry "bad and X_bad" through
         both remaining executions. *)
      have switch: "set_spmf (exec_gpv oracle1 (c r1) s1') × set_spmf (exec_gpv oracle2 (c r2) s2')
             {((a, s1'), b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')}"
        if "¬ bad1 s1" " ⊢g gpv " "¬ bad2 s2" and X: "X s1 s2" and out: "IO out c  set_spmf (the_gpv gpv)"
        and joint: "((r1, s1'), (r2, s2'))  set_spmf (joint_oracle s1 s2 out)" 
        and bad2: "bad2 s2'"
        for out c r1 s1' r2 s2'
      proof(clarify; rule conjI)
        from step.prems(3) out have outs: "out  outs_ℐ " by(rule WT_gpv_OutD)
        from bad2 3[OF joint X this] have bad1: "bad1 s1'  X_bad s1' s2'" by simp_all

        have s1': "(r1, s1')  set_spmf (oracle1 s1 out)" and s2': "(r2, s2')  set_spmf (oracle2 s2 out)"
          using setD[OF X outs joint] by simp_all
        have resp: "r1  responses_ℐ  out" using WT_oracle1 s1' outs by(rule WT_calleeD)
        with step.prems(3) out have WT1: " ⊢g c r1 " by(rule WT_gpv_ContD)
        have resp: "r2  responses_ℐ  out" using WT_oracle2 s2' outs by(rule WT_calleeD)
        with step.prems(3) out have WT2: " ⊢g c r2 " by(rule WT_gpv_ContD)

        fix r1' s1'' r2' s2''
        assume s1'': "(r1', s1'')  set_spmf (exec_gpv oracle1 (c r1) s1')"
          and s2'': "(r2', s2'')  set_spmf (exec_gpv oracle2 (c r2) s2')"
        have *: "bad1 s1''  X_bad s1'' s2'" using bad2 s1'' bad1 WT1
          by(rule callee_invariant_on.exec_gpv_invariant[OF bad_sticky1])
        have "bad2 s2''  X_bad s1'' s2''" using _ s2'' _ WT2
          by(rule callee_invariant_on.exec_gpv_invariant[OF bad_sticky2])(simp_all add: bad2 *)
        then show "bad1 s1'' = bad2 s2''" "if bad2 s2'' then X_bad s1'' s2'' else r1' = r2'  X s1'' s2''"
          using * by(simp_all)
      qed
      (* Three subgoals: bad from the start, the switch step, and the recursive
         (still good) step handled by the induction hypothesis. *)
      show ?case using step.prems
        apply(clarsimp simp add: bind_UNION step.IH 3 WT_gpv_OutD WT_gpv_ContD del: subsetI intro!: UN_least split: generat.split if_split_asm)
        subgoal by(auto 4 3 dest: callee_invariant_on.exec_gpv_invariant[OF bad_sticky1, rotated] callee_invariant_on.exec_gpv_invariant[OF bad_sticky2, rotated] 3)
        apply(intro strip conjI)
        subgoal by(drule (6) switch) auto
        subgoal by(auto 4 3 intro!: step.IH[THEN order.trans] del: subsetI dest: 3 setD[rotated 2] simp add: WT_gpv_OutD WT_gpv_ContD intro: WT_gpv_ContD intro!: WT_calleeD[OF WT_oracle1])
        done
    qed
    then show "x y. (x, y)  set_spmf ?pq  ?R x y" by auto
  qed
qed

(* Variant of exec_gpv_oracle_bisim_bad_plossless whose losslessness hypothesis
   is lossless_gpv instead of plossless_gpv. The proof converts that hypothesis
   with lossless_imp_plossless_gpv (using WT_gpv) and passes all other
   assumptions through unchanged. *)
lemma exec_gpv_oracle_bisim_bad':
  fixes s1 :: 's1 and s2 :: 's2 and X :: "'s1  's2  bool"
  and oracle1 :: "'s1  'a  ('b × 's1) spmf"
  and oracle2 :: "'s2  'a  ('b × 's2) spmf"
  assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
  and bad: "bad1 s1 = bad2 s2"
  and bisim: "s1 s2 x.  X s1 s2; x  outs_ℐ    rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
  and bad_sticky1: "s2. bad2 s2  callee_invariant_on oracle1 (λs1. bad1 s1  X_bad s1 s2) "
  and bad_sticky2: "s1. bad1 s1  callee_invariant_on oracle2 (λs2. bad2 s2  X_bad s1 s2) "
  and lossless1: "s1 x.  bad1 s1; x  outs_ℐ    lossless_spmf (oracle1 s1 x)"
  and lossless2: "s2 x.  bad2 s2; x  outs_ℐ    lossless_spmf (oracle2 s2 x)"
  and lossless: "lossless_gpv  gpv"
  and WT_oracle1: "s1.  ⊢c oracle1 s1 " (* stronger than the invariants above because unconditional *)
  and WT_oracle2: "s2.  ⊢c oracle2 s2 "
  and WT_gpv: " ⊢g gpv "
  shows "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
using assms(1-7) lossless_imp_plossless_gpv[OF lossless WT_gpv] assms(9-)
by(rule exec_gpv_oracle_bisim_bad_plossless)

(* Version of exec_gpv_oracle_bisim_bad' in which the bisimulation, stickiness and
   losslessness premises additionally assume the callee invariants I1 and I2, which
   must hold for the initial states s1 and s2.
   Proof outline: assume local type copies of {s. I1 s} and {s. I2 s}, move oracles,
   states and predicates to these types with transfer rules, apply
   exec_gpv_oracle_bisim_bad' there, transfer the result back, and finally remove
   the assumed type definitions with cancel_type_definition. *)
lemma exec_gpv_oracle_bisim_bad_invariant:
  fixes s1 :: 's1 and s2 :: 's2 and X :: "'s1  's2  bool" and I1 :: "'s1  bool" and I2 :: "'s2  bool"
  and oracle1 :: "'s1  'a  ('b × 's1) spmf"
  and oracle2 :: "'s2  'a  ('b × 's2) spmf"
  assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
  and bad: "bad1 s1 = bad2 s2"
  and bisim: "s1 s2 x.  X s1 s2; x  outs_ℐ ; I1 s1; I2 s2   rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
  and bad_sticky1: "s2.  bad2 s2; I2 s2   callee_invariant_on oracle1 (λs1. bad1 s1  X_bad s1 s2) "
  and bad_sticky2: "s1.  bad1 s1; I1 s1   callee_invariant_on oracle2 (λs2. bad2 s2  X_bad s1 s2) "
  and lossless1: "s1 x.  bad1 s1; I1 s1; x  outs_ℐ    lossless_spmf (oracle1 s1 x)"
  and lossless2: "s2 x.  bad2 s2; I2 s2; x  outs_ℐ    lossless_spmf (oracle2 s2 x)"
  and lossless: "lossless_gpv  gpv"
  and WT_gpv: " ⊢g gpv "
  and I1: "callee_invariant_on oracle1 I1 "
  and I2: "callee_invariant_on oracle2 I2 "
  and s1: "I1 s1"
  and s2: "I2 s2"
  shows "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
  including lifting_syntax
proof -
  interpret I1: callee_invariant_on oracle1 I1  by(fact I1)
  interpret I2: callee_invariant_on oracle2 I2  by(fact I2)
  from s1 have nonempty1: "{s. I1 s}  {}" by auto
  (* Hypothetical type copies of the two invariant sets; the assumptions are
     discharged after the block by cancel_type_definition. *)
  { assume "(Rep1 :: 's1'  's1) Abs1. type_definition Rep1 Abs1 {s. I1 s}"
      and "(Rep2 :: 's2'  's2) Abs2. type_definition Rep2 Abs2 {s. I2 s}"
    then obtain Rep1 :: "'s1'  's1" and Abs1 and Rep2 :: "'s2'  's2" and Abs2
      where td1: "type_definition Rep1 Abs1 {s. I1 s}" and td2: "type_definition Rep2 Abs2 {s. I2 s}"
      by blast
    interpret td1: type_definition Rep1 Abs1 "{s. I1 s}" by(rule td1)
    interpret td2: type_definition Rep2 Abs2 "{s. I2 s}" by(rule td2)
    (* Correspondence relations between the original and the copied state types. *)
    define cr1 where "cr1  λx y. x = Rep1 y"
    have [transfer_rule]: "bi_unique cr1" "right_total cr1" using td1 cr1_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr1 = I1" using type_definition_Domainp[OF td1 cr1_def] by simp
    define cr2 where "cr2  λx y. x = Rep2 y"
    have [transfer_rule]: "bi_unique cr2" "right_total cr2" using td2 cr2_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr2 = I2" using type_definition_Domainp[OF td2 cr2_def] by simp

    (* Outputs are related by equality restricted to outputs of the interface. *)
    let ?C = "eq_onp (λout. out  outs_ℐ )"

    (* Copies of the oracles, initial states, bad flags and state relations on the
       new state types, each registered with its transfer rule. *)
    define oracle1' where "oracle1'  (Rep1 ---> id ---> map_spmf (map_prod id Abs1)) oracle1"
    have [transfer_rule]: "(cr1 ===> ?C ===> rel_spmf (rel_prod (=) cr1)) oracle1 oracle1'"
      by(auto simp add: oracle1'_def rel_fun_def cr1_def spmf_rel_map prod.rel_map td1.Abs_inverse eq_onp_def intro!: rel_spmf_reflI intro: td1.Rep[simplified] dest: I1.callee_invariant)
    define oracle2' where "oracle2'  (Rep2 ---> id ---> map_spmf (map_prod id Abs2)) oracle2"
    have [transfer_rule]: "(cr2 ===> ?C ===> rel_spmf (rel_prod (=) cr2)) oracle2 oracle2'"
      by(auto simp add: oracle2'_def rel_fun_def cr2_def spmf_rel_map prod.rel_map td2.Abs_inverse eq_onp_def intro!: rel_spmf_reflI intro: td2.Rep[simplified] dest: I2.callee_invariant)

    define s1' where "s1'  Abs1 s1"
    have [transfer_rule]: "cr1 s1 s1'" using s1 by(simp add: cr1_def s1'_def td1.Abs_inverse)
    define s2' where "s2'  Abs2 s2"
    have [transfer_rule]: "cr2 s2 s2'" using s2 by(simp add: cr2_def s2'_def td2.Abs_inverse)

    define bad1' where "bad1'  (Rep1 ---> id) bad1"
    have [transfer_rule]: "(cr1 ===> (=)) bad1 bad1'" by(simp add: rel_fun_def bad1'_def cr1_def)
    define bad2' where "bad2'  (Rep2 ---> id) bad2"
    have [transfer_rule]: "(cr2 ===> (=)) bad2 bad2'" by(simp add: rel_fun_def bad2'_def cr2_def)

    define X' where "X'  (Rep1 ---> Rep2 ---> id) X"
    have [transfer_rule]: "(cr1 ===> cr2 ===> (=)) X X'" by(simp add: rel_fun_def X'_def cr1_def cr2_def)
    define X_bad' where "X_bad'  (Rep1 ---> Rep2 ---> id) X_bad"
    have [transfer_rule]: "(cr1 ===> cr2 ===> (=)) X_bad X_bad'" by(simp add: rel_fun_def X_bad'_def cr1_def cr2_def)

    (* The restricted gpv is related to itself by ?C on outputs, which the
       unrestricted gpv need not be. *)
    define gpv' where "gpv'  restrict_gpv  gpv"
    have [transfer_rule]: "rel_gpv (=) ?C gpv' gpv'"
      by(fold eq_onp_top_eq_eq)(auto simp add: gpv.rel_eq_onp eq_onp_same_args pred_gpv_def gpv'_def dest: in_outs'_restrict_gpvD)

    (* Collect the premises of exec_gpv_oracle_bisim_bad' for the copied setting. *)
    have "if bad2' s2' then X_bad' s1' s2' else X' s1' s2'" using * by transfer
    moreover have "bad1' s1'  bad2' s2'" using bad by transfer
    moreover have x: "?C x x" if "x  outs_ℐ " for x using that by(simp add: eq_onp_def)
    have "rel_spmf (λ(a, s1') (b, s2'). (bad1' s1'  bad2' s2')  (if bad2' s2' then X_bad' s1' s2' else a = b  X' s1' s2')) (oracle1' s1 x) (oracle2' s2 x)"
      if "X' s1 s2" and "x  outs_ℐ " for s1 s2 x using that(1) supply that(2)[THEN x, transfer_rule]
      by(transfer)(rule bisim[OF _ that(2)])
    moreover have [transfer_rule]: "rel_ℐ ?C (=)  " by(rule rel_ℐI)(auto simp add: set_relator_eq_onp eq_onp_same_args rel_set_eq dest: eq_onp_to_eq)
    have "callee_invariant_on oracle1' (λs1. bad1' s1  X_bad' s1 s2) " if "bad2' s2" for s2
      using that unfolding callee_invariant_on_alt_def apply(transfer)
      using bad_sticky1[unfolded callee_invariant_on_alt_def] by blast
    moreover have "callee_invariant_on oracle2' (λs2. bad2' s2  X_bad' s1 s2) " if "bad1' s1" for s1
      using that unfolding callee_invariant_on_alt_def apply(transfer)
      using bad_sticky2[unfolded callee_invariant_on_alt_def] by blast
    moreover have "lossless_spmf (oracle1' s1 x)" if "bad1' s1" "x  outs_ℐ " for s1 x
      using that supply that(2)[THEN x, transfer_rule] by transfer(rule lossless1)
    moreover have "lossless_spmf (oracle2' s2 x)" if "bad2' s2" "x  outs_ℐ " for s2 x
      using that supply that(2)[THEN x, transfer_rule] by transfer(rule lossless2)
    moreover have "lossless_gpv  gpv'" using WT_gpv lossless by(simp add: gpv'_def lossless_restrict_gpvI)
    moreover have " ⊢c oracle1' s1 " for s1 using I1.WT_callee by transfer
    moreover have " ⊢c oracle2' s2 " for s2 using I2.WT_callee by transfer
    moreover have " ⊢g gpv' " by(simp add: gpv'_def)
    ultimately have **: "rel_spmf (λ(a, s1') (b, s2'). bad1' s1' = bad2' s2'  (if bad2' s2' then X_bad' s1' s2' else a = b  X' s1' s2')) (exec_gpv oracle1' gpv' s1') (exec_gpv oracle2' gpv' s2')"
      by(rule exec_gpv_oracle_bisim_bad')
    (* Transfer the result back to the original state types. *)
    have [transfer_rule]: "((=) ===> ?C ===> rel_spmf (rel_prod (=) (=))) oracle2 oracle2"
      "((=) ===> ?C ===> rel_spmf (rel_prod (=) (=))) oracle1 oracle1"
      by(simp_all add: rel_fun_def eq_onp_def prod.rel_eq)
    note [transfer_rule] = bi_unique_eq_onp bi_unique_eq
    from ** have "rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (exec_gpv oracle1 gpv' s1) (exec_gpv oracle2 gpv' s2)"
      by(transfer)
    (* Starting from states that satisfy the invariants, executing the restricted
       gpv' equals executing gpv itself. *)
    also have "exec_gpv oracle1 gpv' s1 = exec_gpv oracle1 gpv s1"
      unfolding gpv'_def using WT_gpv s1 by(rule I1.exec_gpv_restrict_gpv_invariant)
    also have "exec_gpv oracle2 gpv' s2 = exec_gpv oracle2 gpv s2"
      unfolding gpv'_def using WT_gpv s2 by(rule I2.exec_gpv_restrict_gpv_invariant)
    finally have ?thesis . }
  from this[cancel_type_definition, OF nonempty1, cancel_type_definition] s2 show ?thesis by blast
qed

(* Consequence of exec_gpv_oracle_bisim_bad' for an arbitrary result relation R.
   Assumption R requires R to hold whenever the bad flags agree, the non-bad case
   gives equal results related by X, and the bad case gives X_bad; the statement
   then follows by rel_spmf_mono. *)
lemma exec_gpv_oracle_bisim_bad:
  assumes *: "if bad2 s2 then X_bad s1 s2 else X s1 s2"
  and bad: "bad1 s1 = bad2 s2"
  and bisim: "s1 s2 x. X s1 s2  rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (if bad2 s2' then X_bad s1' s2' else a = b  X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
  and bad_sticky1: "s2. bad2 s2  callee_invariant_on oracle1 (λs1. bad1 s1  X_bad s1 s2) "
  and bad_sticky2: "s1. bad1 s1  callee_invariant_on oracle2 (λs2. bad2 s2  X_bad s1 s2) "
  and lossless1: "s1 x. bad1 s1  lossless_spmf (oracle1 s1 x)"
  and lossless2: "s2 x. bad2 s2  lossless_spmf (oracle2 s2 x)"
  and lossless: "lossless_gpv  gpv"
  and WT_oracle1: "s1.  ⊢c oracle1 s1 "
  and WT_oracle2: "s2.  ⊢c oracle2 s2 "
  and WT_gpv: " ⊢g gpv "
  and R: "a s1 b s2.  bad1 s1 = bad2 s2; ¬ bad2 s2  a = b  X s1 s2; bad2 s2  X_bad s1 s2   R (a, s1) (b, s2)"
  shows "rel_spmf R (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
using exec_gpv_oracle_bisim_bad'[OF * bad bisim bad_sticky1 bad_sticky2 lossless1 lossless2 lossless WT_oracle1 WT_oracle2 WT_gpv]
by(rule rel_spmf_mono)(auto intro: R)

(* Specialisation of exec_gpv_oracle_bisim_bad to the full interface ℐ_full, with the
   trivial relation λ_ _. True instantiated for X_bad.  Stickiness of the bad flags is
   therefore stated simply as callee_invariant for bad1 and bad2. *)
lemma exec_gpv_oracle_bisim_bad_full:
  assumes "X s1 s2"
  and "bad1 s1 = bad2 s2"
  and "s1 s2 x. X s1 s2  rel_spmf (λ(a, s1') (b, s2'). bad1 s1' = bad2 s2'  (¬ bad2 s2'  a = b  X s1' s2')) (oracle1 s1 x) (oracle2 s2 x)"
  and "callee_invariant oracle1 bad1"
  and "callee_invariant oracle2 bad2"
  and "s1 x. bad1 s1  lossless_spmf (oracle1 s1 x)"
  and "s2 x. bad2 s2  lossless_spmf (oracle2 s2 x)"
  and "lossless_gpv ℐ_full gpv"
  and R: "a s1 b s2.  bad1 s1 = bad2 s2; ¬ bad2 s2  a = b  X s1 s2   R (a, s1) (b, s2)"
  shows "rel_spmf R (exec_gpv oracle1 gpv s1) (exec_gpv oracle2 gpv s2)"
using assms
by(intro exec_gpv_oracle_bisim_bad[of bad2 s2 "λ_ _. True" s1 X bad1 oracle1 oracle2 ℐ_full gpv R])(auto intro: rel_spmf_mono)

(* enn2ereal commutes with max; shown by unfolding max_def and transferring along the
   ennreal lifting setup. *)
lemma max_enn2ereal: "max (enn2ereal x) (enn2ereal y) = enn2ereal (max x y)"
including ennreal.lifting unfolding max_def by transfer simp

(* Relates the absolute difference between the measures of A under map_spmf f p and
   map_spmf f q to the probability that bad yields True under p.  The ereal_abs_leI
   step and the chain of inequalities in the proof show that this probability is an
   upper bound.
   Premises: bad has the same distribution under p and q (bad_eq), and the joint event
   "f lands in A and bad is False" has the same measure under p and q (not_bad). *)
lemma identical_until_bad:
  assumes bad_eq: "map_spmf bad p = map_spmf bad q"
  and not_bad: "measure (measure_spmf (map_spmf (λx. (f x, bad x)) p)) (A × {False}) = measure (measure_spmf (map_spmf (λx. (f x, bad x)) q)) (A × {False})"
  shows "¦measure (measure_spmf (map_spmf f p)) A - measure (measure_spmf (map_spmf f q)) A¦  spmf (map_spmf bad p) True"
proof -
  (* Rewrite both measures as integrals of indicator functions. *)
  have "¦enn2ereal (measure (measure_spmf (map_spmf f p)) A) - enn2ereal (measure (measure_spmf (map_spmf f q)) A)¦ = 
    ¦enn2ereal (+ x. indicator A (f x) measure_spmf p) - enn2ereal (+ x. indicator A (f x) measure_spmf q)¦"
    unfolding measure_spmf.emeasure_eq_measure[symmetric]
    by(simp add: nn_integral_indicator[symmetric] indicator_vimage[abs_def] o_def)
  (* Split each indicator according to the value of bad. *)
  also have " =
    ¦enn2ereal (+ x. indicator (A × {False}) (f x, bad x) + indicator (A × {True}) (f x, bad x) measure_spmf p) -
     enn2ereal (+ x. indicator (A × {False}) (f x, bad x) + indicator (A × {True}) (f x, bad x) measure_spmf q)¦"
    by(intro arg_cong[where f=abs] arg_cong2[where f="(-)"] arg_cong[where f=enn2ereal] nn_integral_cong)(simp_all split: split_indicator)
  also have " = 
    ¦enn2ereal (emeasure (measure_spmf (map_spmf (λx. (f x, bad x)) p)) (A × {False}) + (+ x. indicator (A × {True}) (f x, bad x) measure_spmf p)) -
     enn2ereal (emeasure (measure_spmf (map_spmf (λx. (f x, bad x)) q)) (A × {False}) + (+ x. indicator (A × {True}) (f x, bad x) measure_spmf q))¦"
    by(subst (1 2) nn_integral_add)(simp_all add: indicator_vimage[abs_def] o_def nn_integral_indicator[symmetric])
  (* The A × {False} summands are equal by not_bad and cancel in the difference. *)
  also have " = ¦enn2ereal (+ x. indicator (A × {True}) (f x, bad x) measure_spmf p) - enn2ereal (+ x. indicator (A × {True}) (f x, bad x) measure_spmf q)¦"
    (is "_ = ¦?x - ?y¦")
    by(simp add: measure_spmf.emeasure_eq_measure not_bad plus_ennreal.rep_eq ereal_diff_add_eq_diff_diff_swap ereal_diff_add_assoc2 ereal_add_uminus_conv_diff)
  also have "  max ?x ?y"
  proof(rule ereal_abs_leI)
    have "?x - ?y  ?x - 0" by(rule ereal_minus_mono)(simp_all)
    also have "  max ?x ?y" by simp
    finally show "?x - ?y  " .

    have "- (?x - ?y) = ?y - ?x"
      by(rule ereal_minus_diff_eq)(simp_all add: measure_spmf.nn_integral_indicator_neq_top)
    also have "  ?y - 0" by(rule ereal_minus_mono)(simp_all)
    also have "  max ?x ?y" by simp
    finally show "- (?x - ?y)  " .
  qed
  (* Forgetting the constraint on f x: compare with the integrals of indicator {True} (bad x). *)
  also have "  enn2ereal (max (+ x. indicator {True} (bad x) measure_spmf p) (+ x. indicator {True} (bad x) measure_spmf q))"
    unfolding max_enn2ereal less_eq_ennreal.rep_eq[symmetric]
    by(intro max.mono nn_integral_mono)(simp_all split: split_indicator)
  (* bad_eq identifies both integrals with the probability of bad under p. *)
  also have " = enn2ereal (spmf (map_spmf bad p) True)"
    using arg_cong2[where f=spmf, OF bad_eq refl, of True, THEN arg_cong[where f=ennreal]]
    unfolding ennreal_spmf_map_conv_nn_integral indicator_vimage[abs_def] by simp
  finally show ?thesis by simp
qed

(* Moves a computation f on the final state into the oracle.  By cond, a callee step
   from a state satisfying I does not change the value of f.  The right-hand side runs
   gpv with a callee whose state carries an additional optional result: when a
   successor state satisfies I and nothing has been stored yet, f is run on that state
   and its result is stored.  The continuation uses the stored result if there is one
   and runs f on the final state otherwise.
   exec_gpv2 is a local alias of exec_gpv, introduced by the defines clause.
   Only proved for the full interface (see the TODO on that assumption). *)
lemma (in callee_invariant_on) exec_gpv_bind_materialize:
  fixes f :: "'s  'r spmf"
  and g :: "'x × 's  'r  'y spmf"
  and s :: "'s"
  defines "exec_gpv2  exec_gpv"
  assumes cond: "s x y s'.  (y, s')  set_spmf (callee s x); I s   f s = f s'"
  and: " = ℐ_full" (* TODO: generalize *)
  shows "bind_spmf (exec_gpv callee gpv s) (λas. bind_spmf (f (snd as)) (g as)) =
    exec_gpv2 (λ(r, s) x. bind_spmf (callee s x) (λ(y, s'). if I s'  r = None then map_spmf (λr. (y, (Some r, s'))) (f s') else return_spmf (y, (r, s')))) gpv (None, s)
     (λ(a, r, s). case r of None  bind_spmf (f s) (g (a, s)) | Some r'  g (a, s) r')"
    (is "?lhs = ?rhs" is "_ = bind_spmf (exec_gpv2 ?callee2 _ _) _")
proof -
  define exec_gpv1 :: "('a, 'b, 's option × 's) callee  ('x, 'a, 'b) gpv  _"
    where [simp]: "exec_gpv1 = exec_gpv"
  (* Step 1: an instrumented run ?track that additionally stores the first successor
     state satisfying I; projecting the extra component away gives the original run. *)
  let ?X = "λs (ss, s'). s = s'"
  let ?callee = "λ(ss, s) x. map_spmf (λ(y, s'). (y, if I s'  ss = None then Some s' else ss, s')) (callee s x)"
  let ?track = "exec_gpv1 ?callee gpv (None, s)"
  have "rel_spmf (rel_prod (=) ?X) (exec_gpv callee gpv s) ?track" unfolding exec_gpv1_def
    by(rule exec_gpv_oracle_bisim[where X="?X"])(auto simp add: spmf_rel_map intro!: rel_spmf_reflI)
  hence "exec_gpv callee gpv s = map_spmf (λ(a, ss, s). (a, s)) ?track"
    by(auto simp add: spmf_rel_eq[symmetric] spmf_rel_map elim: rel_spmf_mono)
  hence "?lhs = bind_spmf ?track (λ(a, s'', s'). bind_spmf (f s') (g (a, s')))"
    by(simp add: bind_map_spmf o_def split_def)
  (* Step 2: a stored state has the same f-value as the current state, so f may be
     run on the stored state instead of the final one. *)
  also let ?inv = "λ(ss, s). case ss of None  True | Some s'  f s = f s'  I s'  I s"
  interpret inv: callee_invariant_on "?callee" "?inv" 
    by unfold_locales(auto 4 4 split: option.split if_split_asm dest: cond callee_invariant)
  have "bind_spmf ?track (λ(a, s'', s'). bind_spmf (f s') (g (a, s'))) =
    bind_spmf ?track (λ(a, ss', s'). bind_spmf (f (case ss' of None  s' | Some s''  s'')) (g (a, s')))"
    (is "_ = ?rhs'")
    by(rule bind_spmf_cong[OF refl])(auto dest!: inv.exec_gpv_invariant split: option.split_asm)
  also
  (* Once a state (track_Some) or a sampled result (sample_Some) is stored, the
     respective instrumented callee behaves like the plain callee and keeps the
     stored value. *)
  have track_Some: "exec_gpv ?callee gpv (Some ss, s) = map_spmf (λ(a, s). (a, Some ss, s)) (exec_gpv callee gpv s)"
    for s ss :: 's and gpv :: "('x, 'a, 'b) gpv"
  proof -
    let ?X = "λ(ss', s') s. s = s'  ss' = Some ss"
    have "rel_spmf (rel_prod (=) ?X) (exec_gpv ?callee gpv (Some ss, s)) (exec_gpv callee gpv s)"
      by(rule exec_gpv_oracle_bisim[where X="?X"])(auto simp add: spmf_rel_map intro!: rel_spmf_reflI)
    thus ?thesis by(auto simp add: spmf_rel_eq[symmetric] spmf_rel_map elim: rel_spmf_mono)
  qed
  have sample_Some: "exec_gpv ?callee2 gpv (Some r, s) = map_spmf (λ(a, s). (a, Some r, s)) (exec_gpv callee gpv s)" 
    for s :: 's and r :: 'r and gpv :: "('x, 'a, 'b) gpv"
  proof -
    let ?X = "λ(r', s') s. s' = s  r' = Some r"
    have "rel_spmf (rel_prod (=) ?X) (exec_gpv ?callee2 gpv (Some r, s)) (exec_gpv callee gpv s)"
      by(rule exec_gpv_oracle_bisim[where X="?X"])(auto simp add: spmf_rel_map map_spmf_conv_bind_spmf[symmetric] split_def intro!: rel_spmf_reflI)
    then show ?thesis by(auto simp add: spmf_rel_eq[symmetric] spmf_rel_map elim: rel_spmf_mono)
  qed
  (* Step 3: storing the state and storing the sampled result agree; shown by two
     inequalities, each by fixpoint induction on one side. *)
  have "?rhs' = ?rhs"
    ― ‹Actually, parallel fixpoint induction should be used here, but then we cannot use the
      facts @{thm [source] track_Some} and @{thm [source] sample_Some} because fixpoint induction
      replaces @{const exec_gpv} with approximations. So we do two separate fixpoint inductions
      instead and jump from the approximation to the fixpoint when the state has been found.›
  proof(rule spmf.leq_antisym)
    show "ord_spmf (=) ?rhs' ?rhs" unfolding exec_gpv1_def
    proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct_strong)
      case adm show ?case by simp
      case bottom show ?case by simp
      case (step exec_gpv')
      show ?case unfolding exec_gpv2_def
        apply(rewrite in "ord_spmf _ _ " exec_gpv.simps)
        apply(clarsimp split: generat.split simp add: bind_map_spmf intro!: ord_spmf_bind_reflI split del: if_split)
        subgoal for out rpv ret s'
          apply(cases "I s'")
          subgoal
            apply simp
            apply(rule spmf.leq_trans)
             apply(rule ord_spmf_bindI[OF step.hyps])
             apply hypsubst
             apply(rule spmf.leq_refl)
            apply(simp add: track_Some sample_Some bind_map_spmf o_def)
            apply(subst bind_commute_spmf)
            apply(simp add: split_def)
            done
          subgoal
            apply simp
            apply(rule step.IH[THEN spmf.leq_trans])
            apply(simp add: split_def exec_gpv2_def)
            done
          done
        done
    qed
    show "ord_spmf (=) ?rhs ?rhs'" unfolding exec_gpv2_def
    proof(induction arbitrary: gpv s rule: exec_gpv_fixp_induct_strong)
      case adm show ?case by simp
      case bottom show ?case by simp
      case (step exec_gpv')
      show ?case unfolding exec_gpv1_def
        apply(rewrite in "ord_spmf _ _ " exec_gpv.simps)
        apply(clarsimp split: generat.split simp add: bind_map_spmf intro!: ord_spmf_bind_reflI split del: if_split)
        subgoal for out rpv ret s'
          apply(cases "I s'")
          subgoal
            apply(simp add: bind_map_spmf o_def)
            apply(rule spmf.leq_trans)
             apply(rule ord_spmf_bind_reflI)
             apply(rule ord_spmf_bindI)
              apply(rule step.hyps)
             apply hypsubst
             apply(rule spmf.leq_refl)
            apply(simp add: track_Some sample_Some bind_map_spmf o_def)
            apply(subst bind_commute_spmf)
            apply(simp add: split_def)
            done
          subgoal
            apply simp
            apply(rule step.IH[THEN spmf.leq_trans])
            apply(simp add: split_def exec_gpv2_def)
            done
          done
        done
    qed
  qed
  finally show ?thesis .
qed

(* gpv_stop gpv behaves like gpv, except that results are wrapped in Some and inputs
   are optional: outputs are left unchanged (id), on input Some input' the computation
   continues as before, and on input None it terminates with result None. *)
primcorec gpv_stop :: "('a, 'c, 'r) gpv  ('a option, 'c, 'r option) gpv"
where
  "the_gpv (gpv_stop gpv) = 
   map_spmf (map_generat Some id (λrpv input. case input of None  Done None | Some input'  gpv_stop (rpv input'))) 
     (the_gpv gpv)"

(* On a finished computation, gpv_stop wraps the result in Some. *)
lemma gpv_stop_Done [simp]: "gpv_stop (Done x) = Done (Some x)"
by(rule gpv.expand) simp

(* gpv_stop maps the failing computation to the failing computation. *)
lemma gpv_stop_Fail [simp]: "gpv_stop Fail = Fail"
by(rule gpv.expand) simp

(* On Pause, gpv_stop keeps the output; the new continuation stops with Done None on
   input None and otherwise continues with the stopped original continuation. *)
lemma gpv_stop_Pause [simp]: "gpv_stop (Pause out rpv) = Pause out (λinput. case input of None  Done None | Some input'  gpv_stop (rpv input'))"
by(rule gpv.expand) simp

(* gpv_stop of a lifted spmf is the lifted spmf with Some mapped over its results. *)
lemma gpv_stop_lift_spmf [simp]: "gpv_stop (lift_spmf p) = lift_spmf (map_spmf Some p)"
by(rule gpv.expand)(simp add: spmf.map_comp o_def)

(* gpv_stop distributes over bind_gpv: a None result of the stopped first computation
   ends the whole computation with Done None, a Some result is passed on to the
   stopped continuation.  Proved by strong coinduction followed by auto. *)
lemma gpv_stop_bind [simp]:
  "gpv_stop (bind_gpv gpv f) = bind_gpv (gpv_stop gpv) (λx. case x of None  Done None | Some x'  gpv_stop (f x'))"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
  (auto 4 3 simp add: spmf_rel_map map_spmf_bind_spmf o_def bind_map_spmf bind_gpv.sel generat.rel_map simp del: bind_gpv_sel' intro!: rel_spmf_bind_reflI generat.rel_refl_strong rel_spmf_reflI rel_funI split!: generat.split option.split)

(* Parametricity and transfer rules for gpv_stop, stated with the lifting syntax for
   function relators. *)
context includes lifting_syntax begin

(* Parametricity with respect to rel_gpv'', whose third argument R relates the inputs;
   on the stopped side results and inputs are related through rel_option.
   Each auxiliary transfer rule is listed once. *)
lemma gpv_stop_parametric':
  notes [transfer_rule] = the_gpv_parametric' Done_parametric' corec_gpv_parametric'
  shows "(rel_gpv'' A C R ===> rel_gpv'' (rel_option A) C (rel_option R)) gpv_stop gpv_stop"
unfolding gpv_stop_def by transfer_prover

(* Parametricity with respect to rel_gpv; registered as a transfer rule. *)
lemma gpv_stop_parametric [transfer_rule]:
  shows "(rel_gpv A C ===> rel_gpv (rel_option A) C) gpv_stop gpv_stop"
unfolding gpv_stop_def by transfer_prover

(* Relates a computation (left, via the identity function) to its stopped version
   (right) when results and inputs are related through pcr_Some. *)
lemma gpv_stop_transfer:
  "(rel_gpv'' A B C ===> rel_gpv'' (pcr_Some A) B (pcr_Some C)) (λx. x) gpv_stop"
apply(rule rel_funI)
subgoal for gpv gpv'
  apply(coinduction arbitrary: gpv gpv')
  apply(drule rel_gpv''D)
  apply(auto simp add: spmf_rel_map generat.rel_map rel_fun_def elim!: pcr_SomeE generat.rel_mono_strong rel_spmf_mono)
  done
done

end
  
(* gpv_stop commutes with map_gpv': the result map f and the input map h are lifted
   with map_option, the output map g is unchanged.  Proved by strong coinduction
   followed by auto. *)
lemma gpv_stop_map' [simp]:
  "gpv_stop (map_gpv' f g h gpv) = map_gpv' (map_option f) g (map_option h) (gpv_stop gpv)"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
  (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI generat.rel_refl_strong split!: option.split)

(* gpv_stop does not change the interaction bound.  Proved by parallel fixpoint
   induction over the two occurrences of interaction_bound.  In the step case the
   supremum over optional inputs is split into the None input, whose continuation
   Done None contributes 0, and the Some inputs, which are handled by the induction
   hypothesis. *)
lemma interaction_bound_gpv_stop [simp]:
  "interaction_bound consider (gpv_stop gpv) = interaction_bound consider gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_strong_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions interaction_bound.mono interaction_bound.mono interaction_bound_def interaction_bound_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
next
  case (step interaction_bound' interaction_bound'')
  have "(SUP x. interaction_bound' (case x of None  Done None | Some input  gpv_stop (c input))) =
        (SUP input. interaction_bound'' (c input))" (is "?lhs = ?rhs" is "(SUP x. ?f x) = _")
    if "IO out c  set_spmf (the_gpv gpv)" for out c
  proof -
    have "?lhs = sup (interaction_bound' (Done None)) (x. ?f (Some x))"
      by (simp add: UNIV_option_conv image_comp)
    also have "interaction_bound' (Done None) = 0" using step.hyps(1)[of "Done None"] by simp
    also have "(x. ?f (Some x)) = ?rhs" by (simp add: step.IH)
    finally show ?thesis by (simp add: bot_enat_def [symmetric])
  qed
  then show ?case
    by (auto simp add: case_map_generat o_def image_comp cong del: generat.case_cong_weak if_weak_cong intro!: SUP_cong split: generat.split)
qed
  
(* exec_gpv on the stopped computation gpv_stop gpv; the callee returns optional
   responses and the overall result is optional. *)
abbreviation exec_gpv_stop :: "('s  'c  ('r option × 's) spmf)  ('a, 'c, 'r) gpv  's  ('a option × 's) spmf"
where "exec_gpv_stop callee gpv  exec_gpv callee (gpv_stop gpv)"

(* inline on the stopped computation gpv_stop gpv; the callee is itself a gpv that
   returns optional responses. *)
abbreviation inline_stop :: "('s  'c  ('r option × 's, 'c', 'r') gpv)  ('a, 'c, 'r) gpv  's  ('a option × 's, 'c', 'r') gpv"
where "inline_stop callee gpv  inline callee (gpv_stop gpv)"

(* exec_until_stop runs a gpv against two oracle states at once, controlled by a flag b.
   Flag True: each query samples a pair of optional outcomes from joint_oracle.
     - first outcome None: the run yields None (return_pmf None);
     - first outcome Some, second outcome None: unspecified (undefined);
     - both Some, with responses r1 and r2:
         r1 = None            -> continue with Done None, flag stays True;
         r1, r2 both Some     -> continue with rpv applied to r1's value, flag stays True;
         r1 Some, r2 = None   -> continue with rpv applied to r1's value, flag becomes False.
   Flag False: only callee1 is queried, s2 is passed along unchanged, and a
   terminating gpv yields the result None.
   function_internals is switched on in this context; NOTE(review): presumably to keep
   the internal facts exec_until_stop.mono and exec_until_stop_def that
   ord_spmf_exec_gpv_stop refers to - confirm. *)
context
  fixes joint_oracle :: "'s1  's2  'c  (('r option × 's1) option × ('r option × 's2) option) pmf"
  and callee1 :: "'s1  'c  ('r option × 's1) spmf"
  notes [[function_internals]]
begin

partial_function (spmf) exec_until_stop :: "('a option, 'c, 'r) gpv  's1  's2  bool  ('a option × 's1 × 's2) spmf"
where
  "exec_until_stop gpv s1 s2 b =
  (if b then 
     bind_spmf (the_gpv gpv) (λgenerat. case generat of
       Pure x  return_spmf (x, s1, s2)
     | IO out rpv  bind_pmf (joint_oracle s1 s2 out) (λ(a, b).
         case a of None  return_pmf None
         | Some (r1, s1')  (case b of None  undefined | Some (r2, s2') 
           (case (r1, r2) of (None, None)  exec_until_stop (Done None) s1' s2' True
             | (Some r1', Some r2')  exec_until_stop (rpv r1') s1' s2' True
             | (None, Some r2')  exec_until_stop (Done None) s1' s2' True
             | (Some r1', None)  exec_until_stop (rpv r1') s1' s2' False))))
   else
     bind_spmf (the_gpv gpv) (λgenerat. case generat of
       Pure x  return_spmf (None, s1, s2)
     | IO out rpv  bind_spmf (callee1 s1 out) (λ(r1, s1').
         case r1 of None  exec_until_stop (Done None) s1' s2 False
           | Some r1'  exec_until_stop (rpv r1') s1' s2 False)))"

end

(* Relates the stopped runs of gpv against callee1 and callee2 by ord_spmf.
   Premises: as long as stop does not hold for callee2's state, the two callees are
   related by ord_spmf with the step relation ?R (bisim); the initial states are
   related by S and stop does not hold initially (init, go); and once stop holds,
   steps of callee1 preserve S (sticking).
   Proof outline: bisim yields a joint distribution whose marginals are the two
   callees.  The run of callee1 is related to the projection of exec_until_stop on
   that joint distribution, and this projection is below the run of callee2; the two
   facts are composed at the end. *)
lemma ord_spmf_exec_gpv_stop: (* TODO: generalize ord_spmf to support different type variables *)
  fixes callee1 :: "('c, 'r option, 's) callee"
  and callee2 :: "('c, 'r option, 's) callee"
  and S :: "'s  's  bool"
  and gpv :: "('a, 'c, 'r) gpv"
  assumes bisim:
    "s1 s2 x.  S s1 s2; ¬ stop s2   
    ord_spmf (λ(r1, s1') (r2, s2'). le_option r2 r1  S s1' s2'  (r2 = None  r1  None  stop s2'))
      (callee1 s1 x) (callee2 s2 x)"
  and init: "S s1 s2"
  and go: "¬ stop s2"
  and sticking: "s1 s2 x y s1'.  (y, s1')  set_spmf (callee1 s1 x); S s1 s2; stop s2   S s1' s2"
  shows "ord_spmf (rel_prod (ord_option )¯¯ S) (exec_gpv_stop callee1 gpv s1) (exec_gpv_stop callee2 gpv s2)"
proof -
  let ?R = "λ(r1, s1') (r2, s2'). le_option r2 r1  S s1' s2'  (r2 = None  r1  None  stop s2')"
  (* Joint distribution with marginals callee1 (j1) and callee2 (j2) whose support is
     related by ord_option ?R (rel). *)
  obtain joint :: "'s  's  'c  (('r option × 's) option × ('r option × 's) option) pmf"
    where j1: "map_pmf fst (joint s1 s2 x) = callee1 s1 x"
    and j2: "map_pmf snd (joint s1 s2 x) = callee2 s2 x"
    and rel [rule_format, rotated -1]: "(a, b)  set_pmf (joint s1 s2 x). ord_option ?R a b"
    if "S s1 s2" "¬ stop s2" for x s1 s2 using bisim
    apply atomize_elim 
    apply(subst (asm) rel_pmf.simps)
    apply(unfold rel_spmf_simps all_conj_distrib[symmetric] all_simps(6) imp_conjR[symmetric])
    apply(subst all_comm)
    apply(subst (2) all_comm)
    apply(subst choice_iff[symmetric] ex_simps(6))+
    apply fastforce
    done
  note [simp del] = top_apply conversep_iff id_apply
  (* First half: the run of callee1 is related to the projection of the coupled run. *)
  have "¬ stop s2  rel_spmf (rel_prod (ord_option )¯¯ S) (exec_gpv_stop callee1 gpv s1) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 True))"
    and "rel_spmf (rel_prod (ord_option )¯¯ S) (exec_gpv callee1 (Done None :: ('a option, 'c, 'r option) gpv) s1) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (Done None :: ('a option, 'c, 'r) gpv) s1 s2 b))"
    and "stop s2  rel_spmf (rel_prod (ord_option )¯¯ S) (exec_gpv_stop callee1 gpv s1) (map_spmf (λ(x, s1, y). (x, y)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 False))"
    for b using init
  proof(induction arbitrary: gpv s1 s2 b rule: parallel_fixp_induct_2_4[OF partial_function_definitions_spmf partial_function_definitions_spmf exec_gpv.mono exec_until_stop.mono exec_gpv_def exec_until_stop_def, unfolded lub_spmf_empty, case_names adm bottom step])
    case adm show ?case by simp
    { case bottom case 1 show ?case by simp }
    { case bottom case 2 show ?case by simp }
    { case bottom case 3 show ?case by simp }
  next
    case (step exec_gpv' exec_until_stop') case step: 1
    show ?case using step.prems
      apply(rewrite gpv_stop.sel)
      apply(simp add: map_spmf_bind_spmf bind_map_spmf gpv.map_sel)
      apply(rule rel_spmf_bind_reflI)
      apply(clarsimp split!: generat.split)
      apply(rewrite j1[symmetric], assumption+)
      apply(rewrite bind_spmf_def)
      apply(auto 4 3 split!: option.split dest: rel intro: step.IH intro!: rel_pmf_bind_reflI simp add: map_bind_pmf bind_map_pmf)
      done
  next
    case step case 2
    then show ?case by(simp add: conversep_iff)
  next
    case (step exec_gpv' exec_until_stop') case step: 3
    show ?case using step.prems
      apply(simp add: map_spmf_bind_spmf bind_map_spmf gpv.map_sel)
      apply(rule rel_spmf_bind_reflI)
      apply(clarsimp simp add: map_spmf_bind_spmf split!: generat.split)
      apply(rule rel_spmf_bind_reflI)
      apply clarsimp
      apply(drule (2) sticking)
      apply(auto split!: option.split intro: step.IH)
      done
  qed
  note this(1)[OF go]
  also
  (* Second half: the projection of the coupled run is below the run of callee2. *)
  have "¬ stop s2  ord_spmf (=) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 True)) (exec_gpv_stop callee2 gpv s2)"
    and "ord_spmf (=) (map_spmf (λ(x, s1, y). (x, y)) (exec_until_stop joint callee1 (Done None :: ('a option, 'c, 'r) gpv) s1 s2 b)) (return_spmf (None, s2))"
    and "stop s2  ord_spmf (=) (map_spmf (λ(x, s1, s2). (x, s2)) (exec_until_stop joint callee1 (map_gpv Some id gpv) s1 s2 False)) (return_spmf (None, s2))"
    for b using init
  proof(induction arbitrary: gpv s1 s2 b rule: exec_until_stop.fixp_induct[case_names adm bottom step])
    case adm show ?case by simp
    { case bottom case 1 show ?case by simp }
    { case bottom case 2 show ?case by simp }
    { case bottom case 3 show ?case by simp }
  next
    case (step exec_until_stop') case step: 1
    show ?case using step.prems
      apply(rewrite exec_gpv.simps)
      apply(simp add: map_spmf_bind_spmf bind_map_spmf gpv.map_sel)
      apply(rule ord_spmf_bind_reflI)
      apply(clarsimp split!: generat.split simp add: map_bind_pmf bind_spmf_def)
      apply(rewrite j2[symmetric], assumption+)
      apply(auto 4 3 split!: option.split dest: rel intro: step.IH intro!: rel_pmf_bind_reflI simp add: bind_map_pmf)
      done
  next
    case step case 2 thus ?case by simp
  next
    case (step exec_until_stop') case 3
    thus ?case
      apply(simp add: map_spmf_bind_spmf o_def)
      apply(rule ord_spmf_bind_spmfI1)
      apply(clarsimp split!: generat.split simp add: map_spmf_bind_spmf o_def gpv.map_sel)
      apply(rule ord_spmf_bind_spmfI1)
      apply clarsimp
      apply(drule (2) sticking)
      apply(clarsimp split!: option.split simp add: step.IH)
      done
  qed
  note this(1)[OF go]
  finally show ?thesis by(rule rel_pmf_mono)(auto elim!: option.rel_cases)
qed

end

Theory GPV_Applicative

theory GPV_Applicative imports
  Generative_Probabilistic_Value
  SPMF_Applicative
begin

subsection ‹Applicative instance for @{typ "(_, 'out, 'in) gpv"}›

(* Applicative application for gpvs, defined through bind_gpv: first run f to obtain a
   function f', then run x to obtain x', and return f' x'. *)
definition ap_gpv :: "('a  'b, 'out, 'in) gpv  ('a, 'out, 'in) gpv  ('b, 'out, 'in) gpv"
where "ap_gpv f x = bind_gpv f (λf'. bind_gpv x (λx'. Done (f' x')))"

(* Make ap_gpv a variant of the overloaded constant Applicative.ap. *)
adhoc_overloading Applicative.ap ap_gpv

(* Input-only abbreviation: the applicative unit for gpvs is Done. *)
abbreviation (input) pure_gpv :: "'a  ('a, 'out, 'in) gpv"
where "pure_gpv  Done"

(* Laws of ap_gpv and pure_gpv, the registration of gpv as an applicative functor, and
   the interaction of ap_gpv with map_gpv and exec_gpv. *)
context includes applicative_syntax begin

(* Identity law. *)
lemma ap_gpv_id: "pure_gpv (λx. x)  x = x"
by(simp add: ap_gpv_def)

(* Composition law; relies on associativity of bind_gpv. *)
lemma ap_gpv_comp: "pure_gpv (∘)  u  v  w = u  (v  w)"
by(simp add: ap_gpv_def bind_gpv_assoc)

(* Homomorphism law. *)
lemma ap_gpv_homo: "pure_gpv f  pure_gpv x = pure_gpv (f x)"
by(simp add: ap_gpv_def)

(* Interchange law. *)
lemma ap_gpv_interchange: "u  pure_gpv x = pure_gpv (λf. f x)  u"
by(simp add: ap_gpv_def)

(* Register gpv as an applicative functor; the proof obligations are discharged by the
   four laws above (composition with o_def unfolded). *)
applicative gpv
for
  pure: pure_gpv
  ap: ap_gpv
by(rule ap_gpv_id ap_gpv_comp[unfolded o_def[abs_def]] ap_gpv_homo ap_gpv_interchange)+

(* Mapping over results equals applying a pure function. *)
lemma map_conv_ap_gpv: "map_gpv f (λx. x) gpv = pure_gpv f  gpv"
by(simp add: ap_gpv_def map_gpv_conv_bind)

(* Executing an application: run f from state σ, then x from the resulting state, and
   apply the first result to the second. *)
lemma exec_gpv_ap:
  "exec_gpv callee (f  x) σ = 
   exec_gpv callee f σ  (λ(f', σ'). pure_spmf (λ(x', σ''). (f' x', σ''))  exec_gpv callee x σ')"
by(simp add: ap_gpv_def exec_gpv_bind ap_spmf_conv_bind split_def)

(* Special case of a pure function: f is applied to the result component only. *)
lemma exec_gpv_ap_pure [simp]:
  "exec_gpv callee (pure_gpv f  x) σ = pure_spmf (apfst f)  exec_gpv callee x σ"
by(simp add: exec_gpv_ap apfst_def map_prod_def)

end

end

Theory Cyclic_Group

(* Title: Cyclic_Group.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Cyclic groups›

theory Cyclic_Group imports
  "HOL-Algebra.Coset"
begin

(* Extends the monoid record by a distinguished element, the generator. *)
record 'a cyclic_group = "'a monoid" + 
  generator :: 'a ("gı")

locale cyclic_group = group G
  for G :: "('a, 'b) cyclic_group_scheme" (structure)
  +
  assumes generator_closed [intro, simp]: "generator G  carrier G"
  and generator: "carrier G  range (λn :: nat. generator G [^]G n)"
begin

(* Elimination rule: every element of the carrier is a natural-number power of the
   generator. *)
lemma generatorE [elim?]:
  assumes "x  carrier G"
  obtains n :: nat where "x = generator G [^] n"
using generator assms by auto

(* Exponentiation of g (the generator notation of the record) is injective on
   the exponents below order G.
   Proof outline: by linorder_wlog it suffices to treat the ordered case
   (case le). With ?d = m - n one shows that g raised to ?d is the unit. If
   n < m, every carrier element can then be written as a power of g with an
   exponent taken modulo ?d, so order G is bounded by card {..<?d}, which is
   below order G because m < order G: a contradiction. *)
lemma inj_on_generator: "inj_on (([^]) g) {..<order G}"
proof(rule inj_onI)
  fix n m
  assume "n  {..<order G}" "m  {..<order G}"
  hence n: "n < order G" and m: "m < order G" by simp_all
  moreover
  assume "g [^] n = g [^] m"
  ultimately show "n = m"
  (* Reduce to the ordered case; the symmetric case is handled by simp. *)
  proof(induction n m rule: linorder_wlog)
    case sym thus ?case by simp
  next
    case (le n m)
    let ?d = "m - n"
    (* Calculation over integer exponents, using int_pow_diff and the
       assumed equality of the two powers, ending in the unit element. *)
    have "g [^] (int m - int n) = g [^] int m  inv (g [^] int n)"
      by(simp add: int_pow_diff)
    also have "g [^] int m = g [^] int n" by(simp add: le.prems int_pow_int)
    also have "  inv (g [^] (int n)) = 𝟭" by simp
    (* Transfer the result back to the natural-number exponent ?d. *)
    finally have "g [^] ?d = 𝟭" using le.hyps by(simp add: of_nat_diff[symmetric] int_pow_int)
    (* Raw proof block: assuming n < m leads to False. *)
    { assume "n < m"
      have "carrier G  (λn. g [^] n) ` {..<?d}"
      proof
        fix x
        assume "x  carrier G"
        (* x is some power k of g, by the elimination rule generatorE. *)
        then obtain k :: nat where "x = g [^] k" ..
        (* Split the exponent k into its quotient and remainder w.r.t. ?d. *)
        also have " = (g [^] ?d) [^] (k div ?d)  g [^] (k mod ?d)"
          by(simp add: nat_pow_pow nat_pow_mult div_mult_mod_eq)
        (* The quotient part vanishes because g raised to ?d is the unit. *)
        also have " = g [^] (k mod ?d)"
          using g [^] ?d = 𝟭 by simp
        finally show "x  (λn. g [^] n) ` {..<?d}" using n < m by auto
      qed
      (* Cardinality chain: order G against the image, the image against
         {..<?d}, and {..<?d} against order G (using m < order G). *)
      hence "order G  card ((λn. g [^] n) ` {..<?d})"
        by(simp add: order_def card_mono)
      also have "  card {..<?d}" by(rule card_image_le) simp
      also have " < order G" using m < order G by simp
      finally have False by simp }
    (* The strict case is impossible, so the ordered pair must be equal. *)
    with n  m show "n = m" by(auto simp add: order.order_iff_strict)
  qed
qed

(* The carrier of a cyclic group is finite.
   Proof outline: the inverse of g is itself some power n of g, hence g raised
   to Suc n is the unit. Exponents can therefore be reduced modulo Suc n, so
   the range of all natural-number powers of g is the image of the finite set
   {..<Suc n}. The locale assumption generator then transfers finiteness to
   carrier G via finite_subset. *)
lemma finite_carrier: "finite (carrier G)" (* contributed by Dominique Unruh *)
proof -
  (* Express inv g as a power of g. *)
  from generator obtain n :: nat where "g [^] n = inv g"
    by(metis generatorE generator_closed inv_closed)
  then have g1: "g [^] (Suc n) = 𝟭"
    by auto
  (* Powers of g only depend on the exponent modulo Suc n. *)
  have mod: "g [^] m = g [^] (m mod Suc n)" for m
  proof -
    obtain k where "m mod Suc n + Suc n * k = m"
      by (metis mod_less_eq_dividend mod_mod_trivial nat_mod_eq_lemma)
    then have "g [^] m = g [^] (m mod Suc n + Suc n * k)" by simp
    also have " = g [^] (m mod Suc n)"
      unfolding nat_pow_mult[symmetric, OF generator_closed] nat_pow_pow[symmetric, OF generator_closed] g1
      by simp
    finally show ?thesis .
  qed
  (* Hence every power already occurs among the exponents below Suc n. *)
  have "g [^] x  ([^]) g ` {..<Suc n}" for x :: nat by (subst mod) auto
  then have "range (([^]) g :: nat  _)  (([^]) g) ` {..<Suc n}" by auto
  then have "finite (range (([^]) g :: nat  _))" by(rule finite_surj[rotated]) simp
  with generator show ?thesis by(rule finite_subset)
qed

(* The carrier is exactly the image of the exponents below order G under
   exponentiation of g.
   Proof outline: relate the image to carrier G, compare the cardinality of
   the image with order G using injectivity (inj_on_generator, card_image),
   and conclude with card_seteq and finite_carrier after unfolding order_def. *)
lemma carrier_conv_generator: "carrier G = (λn. g [^] n) ` {..<order G}"
proof -
  have "(λn. g [^] n) ` {..<order G}  carrier G" by auto
  moreover have "card ((λn. g [^] n) ` {..<order G})  order G"
    using inj_on_generator by(simp add: card_image)
  ultimately show ?thesis using finite_carrier 
    unfolding order_def by(rule card_seteq[symmetric, rotated])
qed

(* Exponentiation of g is a bijection between the exponents below order G and
   carrier G. Combines inj_on_generator (injectivity) with
   carrier_conv_generator (the image is the carrier). *)
lemma bij_betw_generator_carrier:
  "bij_betw (λn :: nat. g [^] n) {..<order G} (carrier G)"
by(simp add: bij_betw_def inj_on_generator carrier_conv_generator)

(* The order of a cyclic group is positive; obtained from
   order_gt_0_iff_finite together with finite_carrier. *)
lemma order_gt_0: "order G > 0"
  using order_gt_0_iff_finite by(simp add: finite_carrier)

end

(* Stated in the monoid locale: relates order G lying in the range of Suc to
   finiteness of carrier G. Proved by case distinction on order G, using
   order_def, carrier_not_empty and card_ge_0_finite.
   NOTE(review): the connective between the two sides is not visible in this
   copy; presumably an equivalence -- confirm against the original theory. *)
lemma (in monoid) order_in_range_Suc: "order G  range Suc  finite (carrier G)"
by(cases "order G")(auto simp add: order_def carrier_not_empty intro: card_ge_0_finite)

end

Theory Cyclic_Group_SPMF

(* Title: Cyclic_Group_SPMF.thy
  Author: Andreas Lochbihler, ETH Zurich *)

theory Cyclic_Group_SPMF imports
  Cyclic_Group
  "HOL-Probability.SPMF"
begin

(* sample_uniform n is spmf_of_set applied to the set {..<n} of natural
   numbers below n. *)
definition sample_uniform :: "nat  nat spmf"
where "sample_uniform n = spmf_of_set {..<n}"

(* Point probabilities of sample_uniform n: the indicator of {..<n} at x,
   divided by n. *)
lemma spmf_sample_uniform: "spmf (sample_uniform n) x = indicator {..<n} x / n"
by(simp add: sample_uniform_def spmf_of_set)

(* Total weight of sample_uniform n, expressed as the indicator of n being a
   successor (n in range Suc); see the two specialised simp rules that
   follow for the cases 0 and 0 < n. *)
lemma weight_sample_uniform: "weight_spmf (sample_uniform n) = indicator (range Suc) n"
by(auto simp add: sample_uniform_def weight_spmf_of_set split: split_indicator elim: lessE)

(* Simp rule: sample_uniform 0 has weight 0. *)
lemma weight_sample_uniform_0 [simp]: "weight_spmf (sample_uniform 0) = 0"
by(auto simp add: weight_sample_uniform indicator_def)

(* Simp rule: for positive n, sample_uniform n has weight 1. The proof
   rewrites 0 < n into n being a successor (gr0_conv_Suc). *)
lemma weight_sample_uniform_gt_0 [simp]: "0 < n  weight_spmf (sample_uniform n) = 1"
by(auto simp add: weight_sample_uniform indicator_def gr0_conv_Suc)

(* Simp rule relating losslessness of sample_uniform n to n being positive;
   proved by unfolding lossless_spmf_def.
   NOTE(review): the connective between the two sides is not visible in this
   copy; presumably an equivalence -- confirm against the original theory. *)
lemma lossless_sample_uniform [simp]: "lossless_spmf (sample_uniform n)  0 < n"
by(auto simp add: lossless_spmf_def intro: ccontr)

(* Simp rule: for positive n, the support of sample_uniform n is {..<n}. *)
lemma set_spmf_sample_uniform [simp]: "0 < n  set_spmf (sample_uniform n) = {..<n}"
by(simp add: sample_uniform_def)

(* One-time-pad property in a cyclic group: sampling an exponent x with
   sample_uniform (order G) and returning the power of g combined with a fixed
   carrier element c has the same distribution as returning the power of g
   alone.
   NOTE(review): the binary operator between the power and c is not visible in
   this copy; the use of inj_on_multc in the proof suggests the group
   multiplication -- confirm against the original theory.
   Proof outline: case distinction on finiteness of carrier G. In the finite
   case the right-hand side is spmf_of_set (carrier G), and mapping the
   operation with c over it yields spmf_of_set of the image, which is again
   carrier G by endo_inj_surj. *)
lemma (in cyclic_group) sample_uniform_one_time_pad:
  assumes [simp]: "c  carrier G"
  shows
  "map_spmf (λx. g [^] x  c) (sample_uniform (order G)) = 
   map_spmf (λx. g [^] x) (sample_uniform (order G))"
   (is "?lhs = ?rhs")
proof(cases "finite (carrier G)")
  case False
  (* Infinite-carrier case: closed by simp with order_def and
     sample_uniform_def. *)
  thus ?thesis by(simp add: order_def sample_uniform_def)
next
  case True
  (* Factor the left-hand side as a further map over the right-hand side. *)
  have "?lhs = map_spmf (λx. x  c) ?rhs"
    by(simp add: pmf.map_comp o_def option.map_comp)
  (* The right-hand side is spmf_of_set on the carrier. *)
  also have rhs: "?rhs = spmf_of_set (carrier G)"
    using True by(simp add: carrier_conv_generator inj_on_generator sample_uniform_def)
  (* Mapping an injective function over spmf_of_set gives spmf_of_set of the image. *)
  also have "map_spmf (λx. x  c)  = spmf_of_set ((λx. x  c) ` carrier G)"
    by(simp add: inj_on_multc)
  (* The image of the finite carrier under this injective endo-map is the carrier. *)
  also have "(λx. x  c) ` carrier G = carrier G"
    using True by(rule endo_inj_surj)(auto simp add: inj_on_multc)
  finally show ?thesis using rhs by simp
qed

end